From 4a8b30a3dac1490a840e264241b7b057482b46e5 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 16 Aug 2024 20:44:40 +0000 Subject: [PATCH 01/34] moved userbuffers code to TE/common Signed-off-by: Alp Dener --- .gitignore | 1 + .../csrc => common/comm_gemm_overlap}/userbuffers/ipcsocket.cc | 0 .../csrc => common/comm_gemm_overlap}/userbuffers/ipcsocket.h | 0 .../comm_gemm_overlap}/userbuffers/userbuffers-host.cpp | 0 .../csrc => common/comm_gemm_overlap}/userbuffers/userbuffers.cu | 0 .../csrc => common/comm_gemm_overlap}/userbuffers/userbuffers.h | 0 6 files changed, 1 insertion(+) rename transformer_engine/{pytorch/csrc => common/comm_gemm_overlap}/userbuffers/ipcsocket.cc (100%) rename transformer_engine/{pytorch/csrc => common/comm_gemm_overlap}/userbuffers/ipcsocket.h (100%) rename transformer_engine/{pytorch/csrc => common/comm_gemm_overlap}/userbuffers/userbuffers-host.cpp (100%) rename transformer_engine/{pytorch/csrc => common/comm_gemm_overlap}/userbuffers/userbuffers.cu (100%) rename transformer_engine/{pytorch/csrc => common/comm_gemm_overlap}/userbuffers/userbuffers.h (100%) diff --git a/.gitignore b/.gitignore index 6890911c14..30f9246898 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,4 @@ dist/ downloads/ .pytest_cache/ compile_commands.json +.nfs diff --git a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc similarity index 100% rename from transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc rename to transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc diff --git a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h similarity index 100% rename from transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h rename to transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp rename to transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu similarity index 100% rename from transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu rename to transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h similarity index 100% rename from transformer_engine/pytorch/csrc/userbuffers/userbuffers.h rename to transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h From b5e18df1852aef253571385fc049d93027830158 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 23 Aug 2024 20:56:23 +0000 Subject: [PATCH 02/34] moved comm+GEMM overlap code to TE/common Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename transformer_engine/{pytorch/csrc/comm_gemm_overlap.h => common/comm_gemm_overlap/comm_gemm_overlap.cpp} (100%) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/comm_gemm_overlap.h rename to 
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp From 8dfe6d63e4aa620783f006760362e03b3e5ddb42 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Mon, 26 Aug 2024 22:44:02 +0000 Subject: [PATCH 03/34] removed PyTorch depdency from comm+GEMM overlap in TE/common Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 2080 +++++++---------- .../userbuffers/userbuffers-host.cpp | 16 +- .../userbuffers/userbuffers.cu | 40 + .../userbuffers/userbuffers.h | 26 +- .../transformer_engine/comm_gemm_overlap.h | 204 ++ .../transformer_engine/transformer_engine.h | 97 +- .../common/transformer_engine.cpp | 25 + 7 files changed, 1274 insertions(+), 1214 deletions(-) create mode 100644 transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 3b4e126943..d66d5bfa77 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -4,37 +4,29 @@ * See LICENSE for license information. ************************************************************************/ -#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ -#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include +#include +#include + +#include +#include +#include #include "common/common.h" #include "common/util/cuda_driver.h" #include "common/util/logging.h" #include "common/util/system.h" -#include "extensions.h" #include "userbuffers/userbuffers.h" #define HALF_BYTES 2 #define UB_MAX_SM 32 -using namespace torch::indexing; using namespace std::placeholders; -namespace ubuf { +namespace transformer_engine { + +/*************************************************************************************************** + * Comm+GEMM Overlap Common Core + **************************************************************************************************/ bool device_supports_multicast() { int dev, supports_multicast; @@ -56,1248 +48,962 @@ bool ubuf_built_with_mpi() { #endif } -class UbufBootstrapCallbacks : torch::CustomClassHolder { - private: - bool initialized{false}; - bool backend_is_nccl{false}; - std::map pgs; - - public: - UbufBootstrapCallbacks() { -#ifndef NVTE_UB_WITH_MPI - NVTE_ERROR("Internal TE error: Dummy UbufBootstrapCallbacks init without NVTE_UB_WITH_MPI=1!"); -#endif - } // empty constructor for NVTE_UB_WITH_MPI=1 - - UbufBootstrapCallbacks(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group) { - pgs.insert({"world", world_group}); - c10d::ProcessGroup::BackendType backend = world_group->getBackendType(); - backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); - - NVTE_CHECK(intra_node_group->getBackendType() == backend, - "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", - "group!", world_group->getBackendName()); - pgs.insert({"intra", intra_node_group}); - - initialized = true; - } - - ~UbufBootstrapCallbacks() { - for (auto &pg : pgs) pg.second = nullptr; - backend_is_nccl = false; - initialized = false; - } - - void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, - char *group) { - NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ", - "with valid process groups!"); - - 
auto localtensor = - torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; - auto globaltensor = - torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto globaltmp = (backend_is_nccl) ? globaltensor.cuda() : globaltensor; - - std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; - std::vector localchunk = {localtmp}; - auto work = pgs[group]->allgather(globalchunks, localchunk); - work->wait(); - - if (backend_is_nccl) { - globaltensor.copy_(globaltmp.cpu()); - globaltmp = torch::Tensor(); - localtmp = torch::Tensor(); +CommOverlapCore::CommOverlapCore( + int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, + int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { + // Initialize userbuf communicator + if (!_comm_created) { + if (myrank == 0) { + printf("!!! [UB] Create Userbuffers Communicator\n"); } - } - - void ub_barrier(char *group) { - NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ", - "with valid process groups!"); - auto work = pgs[group]->barrier(); - work->wait(); - } -}; - -enum class COMM_TYPE { RS = 0, AG = 1 }; - -enum class UBOverlapAlgo { - BULK_OVERLAP_AG = 0, - BULK_OVERLAP_RS = 1, - SPLIT_PIPELINED_AG_P2P = 2, - SPLIT_PIPELINED_RS = 3, - SPLIT_PIPELINED_RS_P2P = 4, - ATOMIC_GEMM_RS = 5, - ATOMIC_GEMM_AG_P2P = 6, - ATOMIC_GEMM_RS_P2P = 7 -}; - -struct UbufBase { - static inline communicator *_ub_comm{nullptr}; - static inline bool comm_created{false}; -}; -struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { - int _tp_id; - int _tp_size; - int _num_splits; - int _math_sms; - int _ub_reg; - void *_ubuf_ptr; - torch::Tensor _ubuf; - torch::Tensor output_tensor; - torch::Tensor _ubuf_scale_inv; - bool _ubuf_scale_inv_initialized; - torch::Tensor counter; - at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true); - std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm; - int _num_comm_sm; - int _cga_size; - int _use_ce; - bool _atomic_gemm; - - UbufCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size, - int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm, - UbufBootstrapCallbacks &callbacks) { - // Initialize userbuf communicator - if (!comm_created) { - if (myrank == 0) { - printf("!!! 
[UB] Create Userbuffers Communicator\n"); - } #ifdef NVTE_UB_WITH_MPI - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); + create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); #else - create_communicator_grouped2( - &_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5), - std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1); + create_communicator_grouped2( + &_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, allgather_handle, + barrier_handle, 1, 1, tp_size, 1); #endif - comm_created = true; - } - _use_ce = 0; - _num_comm_sm = num_comm_sm; - _cga_size = comm_cga_size; - - // Allocate and register extra userbuffers - int ubuf_bytes = sample.numel() * sample.element_size(); - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); - - if (_ub_comm->myrank == 0) { - printf("!!! [UB] Register UBuf %d\n", _ub_reg); - } - - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { - cudaStream_t stream; - cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); - _stream_compute.push_back( - at::cuda::getStreamFromExternal(stream, stream_main.device_index())); - } + _comm_created = true; + } + _use_ce = static_cast(use_ce); + _num_comm_sm = num_comm_sm; + _cga_size = comm_cga_size; + + for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1)); + _stream_compute.push_back(std::move(stream)); + } - _num_splits = num_splits; - _tp_size = tp_size; - _tp_id = (_ub_comm->myrank % _tp_size); - _ubuf_scale_inv_initialized = false; - - // Set the number of SMs for GEMM with margin - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; - _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); - - output_tensor = torch::Tensor(); - _atomic_gemm = atomic_gemm; - if (_atomic_gemm) { - auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({num_splits * 2}, counter_options); - counter.index_put_({Slice(None, num_splits)}, 1); - } - // CUDA event creation - cudaEventCreateWithFlags(&_start_compute, 0); - cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_d2dcopy, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_comm, 0); + _num_splits = num_splits; + _rank = _ub_comm->myrank; + _tp_size = tp_size; + _tp_id = _rank % _tp_size; + + // Set the number of SMs for GEMM with margin + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + _math_sms = (set_sm_margin) ? 
prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; + _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); + + _atomic_gemm = atomic_gemm; + if (_atomic_gemm) { + void *counter_ptr; + size_t counter_bytes = _num_splits * 2 * sizeof(int32_t); + NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes)); + NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes)); + NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2)); + _counter = TensorWrapper(counter_ptr, std::vector{(size_t)_num_splits * 2}, + DType::kInt32); } + // CUDA event creation + cudaEventCreateWithFlags(&_start_compute, 0); + cudaEventCreateWithFlags(&_stop_compute, 0); + cudaEventCreateWithFlags(&_start_comm, 0); + cudaEventCreateWithFlags(&_stop_comm, 0); +} - ~UbufCommOverlap() { - cudaEventDestroy(_stop_comm); - cudaEventDestroy(_start_comm); - cudaEventDestroy(_start_d2dcopy); - cudaEventDestroy(_stop_compute); - cudaEventDestroy(_start_compute); +CommOverlapCore::~CommOverlapCore() { + cudaEventDestroy(_stop_comm); + cudaEventDestroy(_start_comm); + cudaEventDestroy(_stop_compute); + cudaEventDestroy(_start_compute); - for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); + if (_atomic_gemm) cudaFree(_counter.dptr()); - if (comm_created) { + for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); + + if (_comm_created) { #ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); + destroy_communicator_mpi(_ub_comm); #else - destroy_communicator(_ub_comm); + destroy_communicator(_ub_comm); #endif - comm_created = false; - } + _comm_created = false; + } } - /* - ** Bulk GEMM + COMM - ** This function assumes the communication input is pre-copied to _ubuf - */ - std::vector bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type, - at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get the current userbuf offset - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); - int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type == COMM_TYPE::RS) { - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - } +/*************************************************************************************************** + * Comm+GEMM Overlap Base (Pipelined / Collective) + **************************************************************************************************/ + +CommOverlapBase::CommOverlapBase( + const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, + int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, + int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool atomic_gemm) + : CommOverlapCore( + myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, + 
barrier_handle, num_splits, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, + false, atomic_gemm) { + _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); + NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, + "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", + "or 2 (multi-atomic)."); + + NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); + size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + void *buffer_ptr; + _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); + if (_ub_comm->myrank == 0) + printf("!!! [UB] Register UBuf %d\n", _ub_reg); + _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); + + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); +} + +CommOverlapBase::~CommOverlapBase() { + cudaEventDestroy(_start_d2dcopy); + cudaStreamDestroy(_stream_comm); +} - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); +/* +** Bulk GEMM + COMM +** This function assumes the communication input is pre-copied to _ubuf +*/ +void CommOverlapBase::bulk_overlap( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, const TensorWrapper &rs_output, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication: AG and RS + int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size + if (comm_type == CommOverlapType::AG) { + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + } else { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + comm_elements *= 2; + assert(rs_output.numel() == _ubuf.numel() / _tp_size); + assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); + assert(rs_output.element_size() == 2); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + reducescatter2_userbuff_fp8<__nv_fp8_e5m2>( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + } else { + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + } + } - // Communication: AG and RS - if (_comm_type == COMM_TYPE::AG) { - allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, (cudaStream_t)_stream_comm); - } else if (_comm_type == COMM_TYPE::RS) { + assert(pre_gelu_out.numel() == 0); + nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, + grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, + stream_main); + + _ub_comm->sms = ori_sms; + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // CommOverlapBase::bulk_overlap + +/* +** Split FPROP 
GEMM + ReduceScatter +*/ +void CommOverlapBase::atomic_gemm_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, const TensorWrapper &rs_output, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions + size_t m = A.size(0); + size_t k = A.size(1); + size_t n = B.size(0); + size_t m_chunk = m / _num_splits; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Get input, output, and workspace data pointers + char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); + char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + + // Reset atomic counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _num_splits, false, stream_main); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0)); + + assert(pre_gelu_out.numel() == 0); + + auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr); + auto workspace_chunk = TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, + workspace.dtype()); + nvte_cublas_atomic_gemm( + A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, + workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _num_splits, 0, true, + _counter.data(), _stream_compute[0]); + + for (int i = 0; i < _num_splits; i++) { + if (_rs_kernel_type == 1) { + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } if (_ubuf.element_size() == 1) { assert(_ubuf_scale_inv_initialized); - comm_elements *= 2; - float *scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - assert(rs_output.numel() == _ubuf.numel() / _tp_size); - assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); - assert(rs_output.element_size() == 2); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, scale_inv_ptr, _ub_reg, 0, - comm_elements, _ub_comm, - (cudaStream_t)_stream_comm); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_strided_atomic_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, + _num_splits, &counter_ptr[i], _ub_comm, _stream_comm);); } else { - reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, - (cudaStream_t)_stream_comm); + reducescatter2_userbuff_strided_atomic( + rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, _num_splits, &counter_ptr[i], + _ub_comm, _stream_comm); } + } else if (_rs_kernel_type == 2) { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_strided_multiatomic_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + counter_ptr, _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided_multiatomic( + rs_output_ptr, 
_ub_reg, m_chunk, m_chunk, n, m, _num_splits, counter_ptr, _ub_comm, + _stream_comm); + } + break; } else { - NVTE_ERROR("Not supported communication type."); + consumer(counter_ptr, i, _stream_comm); + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, _ub_comm, + _stream_comm);); + } else { + reducescatter2_userbuff_strided( + rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, _ub_comm, _stream_comm); + } } - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - te_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, D, D_scale, - D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, workspaceSize, - accumulate, use_split_accumulator, _math_sms); - - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); - - // Generate output tensor from userbuf data pointer - int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); - _ub_comm->sms = ori_sms; - - return {D, output_tensor}; - } // bulk_overlap - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions - int m = A.size(0); - int k = A.size(1); - int n = B.size(0); - int m_chunk = m / _num_splits; - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.data_ptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int *counter_ptr = reinterpret_cast(counter.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - - torch::Tensor input_a = torch::from_blob(input_a_chunk_ptr, {m, k}, A.options()); - torch::Tensor output_d = torch::from_blob(output_buf_chunk_ptr, 
{n, m}, _ubuf.options()); - // torch::zeros({n, m}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[0]); - te_atomic_gemm(input_a, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_d, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/, - counter); + rs_output_ptr += m_chunk * rs_output.element_size(); + } - for (int i = 0; i < _num_splits; i++) { - const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); - if (env_p != nullptr && env_p[0] == '1') { - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_strided_atomic_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, - _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _num_splits, &counter_ptr[i], _ub_comm, - (cudaStream_t)_stream_comm); - } - } else if (env_p != nullptr && env_p[0] == '2') { - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_strided_multiatomic_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, - counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, - m, _num_splits, counter_ptr, _ub_comm, - (cudaStream_t)_stream_comm); - } - break; + _ub_comm->sms = ori_sms; + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[0])); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // split_overlap_rs + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlapBase::split_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, const TensorWrapper &rs_output, cudaStream_t stream_main) { + // Get GEMM dimensions + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + size_t m = A.size(0); + size_t k = A.size(1); + size_t n = B.size(0); + size_t m_chunk = m / _num_splits; + size_t input_a_chunk_size = m_chunk * k; + size_t output_chunk_size = n * m_chunk; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Get input, output, and workspace data pointers + char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); + char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + + // Catch up the default 
torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0)); + + assert(pre_gelu_out.numel() == 0); + + if (gemm_overlap) { + auto input_a_chunk = TensorWrapper( + A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); + auto output_chunk = TensorWrapper( + _ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); + auto workspace_chunk = TensorWrapper( + workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm( + input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, + _stream_compute[0]); + + for (int i = 1; i < _num_splits; i++) { + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * D.element_size(); + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + + input_a_chunk = TensorWrapper( + reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, A.dtype(), nullptr, nullptr, + A.scale_inv()); + output_chunk = TensorWrapper( + reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, D.dtype(), D.amax(), + D.scale(), nullptr); + workspace_chunk = TensorWrapper( + reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, + workspace.dtype()); + + nvte_cublas_gemm( + input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, + _math_sms, _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA(cudaEventRecord( + _start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication chunk + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, + m, _ub_comm, _stream_comm);); } else { - assert(_ubuf.element_size() != 1); - consumer(counter_ptr, i, (cudaStream_t)_stream_comm); - reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); + reducescatter2_userbuff_stridedoutput( + rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, + _stream_comm); } rs_output_ptr += m_chunk * rs_output.element_size(); } + int last_compute_stream_id = + (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); + NVTE_CHECK_CUDA( + cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); - _ub_comm->sms = ori_sms; - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); - at::cuda::setCurrentCUDAStream(stream_main); - - return; - } // split_overlap_rs - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void split_overlap_rs(at::Tensor A, 
at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, at::Tensor rs_output) { - // Get GEMM dimensions - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - int m = A.size(0); - int k = A.size(1); - int n = B.size(0); - int m_chunk = m / _num_splits; - int input_a_chunk_size = m_chunk * k; - int output_chunk_size = n * m_chunk; - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.data_ptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); + // Last communication chunk with max SM + _ub_comm->sms = UB_MAX_SM; + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput( + rs_output_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, + _stream_comm); } - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - - if (gemm_overlap) { - torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[0]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); - - for (int i = 1; i < _num_splits; i++) { - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - 
{workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - NVTE_CHECK_CUDA(cudaEventRecord( - _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication chunk - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, - m, _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, - (cudaStream_t)_stream_comm); - } - - rs_output_ptr += m_chunk * rs_output.element_size(); + } else { + for (int i = 0; i < _num_splits; i++) { + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + + auto input_a_chunk = TensorWrapper( + reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, A.dtype(), nullptr, nullptr, + A.scale_inv()); + auto output_chunk = TensorWrapper( + reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, D.dtype(), D.amax(), + D.scale(), nullptr); + auto workspace_chunk = TensorWrapper( + reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, + workspace.dtype()); + + nvte_cublas_gemm( + input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, + _math_sms, _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, + _stream_compute[i % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication chunk. 
Uses MAX_SM at the last chunk + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; } - int last_compute_stream_id = - (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); - NVTE_CHECK_CUDA( - cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Last communication chunk with max SM - _ub_comm->sms = UB_MAX_SM; if (_ubuf.element_size() == 1) { assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, + D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); + rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm);); } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, - (_num_splits - 1) * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); + reducescatter2_userbuff_stridedoutput( + rs_output_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm); } - } else { - for (int i = 0; i < _num_splits; i++) { - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, - (cudaStream_t)_stream_compute[i % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication chunk. 
Uses MAX_SM at the last chunk - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, - m_chunk, n, m, _ub_comm, - (cudaStream_t)_stream_comm); - } - rs_output_ptr += m_chunk * rs_output.element_size(); - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - } - } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - _ub_comm->sms = ori_sms; - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); - at::cuda::setCurrentCUDAStream(stream_main); - return; - } // split_overlap_rs - - void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { - _ubuf_scale_inv = scale_inv; - _ubuf_scale_inv_initialized = true; - } - - bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } - /* - ** Helper function to copy input to _ubuf - */ - void copy_input_to_ubuf(torch::Tensor input, int comm_type) { - char *ubuf_ptr = reinterpret_cast(_ubuf.data_ptr()); - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type == COMM_TYPE::AG) { - if ((input.numel() * _tp_size) != _ubuf.numel() || - input.element_size() != _ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - } else { - if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } + rs_output_ptr += m_chunk * rs_output.element_size(); + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); } + } - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)_stream_comm)); + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // CommOverlapBase::split_overlap_rs + +/*************************************************************************************************** + * Comm+GEMM Overlap P2P Base (Ring-Exchange) + **************************************************************************************************/ + +CommOverlapP2PBase::CommOverlapP2PBase( + const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, + int mylocal, int numlocal, 
int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, CommOverlapType comm_type, + int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm, bool aggregate) + : CommOverlapCore( + myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, + barrier_handle, tp_size, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, + use_ce, atomic_gemm) { + _is_reduce_scatter = comm_type == CommOverlapType::RS; + _aggregate = aggregate; + + // Create workspace tensor with userbuffer + NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); + size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + int buffer_chunk_bytes = buffer_bytes / tp_size; + _num_ubuf_chunks = tp_size; + if (_is_reduce_scatter) { + // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk + // outputs for reduction at the end of the pipelining. + buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); + _num_ubuf_chunks = tp_size * 2 - 1; } - torch::Tensor &get_ubuf_output(int comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); - if (_comm_type == COMM_TYPE::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); - return output_tensor; + void *buffer_ptr; + _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); + if (_rank == 0) + printf("!!! 
[UBP2P] Register UBuf %d\n", _ub_reg); + _ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + buffer_dtype); + + // Create tensor chunks for easy management + char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); + for (int i = 0; i < _num_ubuf_chunks; i++) { + _ubufs.push_back(TensorWrapper( + reinterpret_cast(ubuf_byte_ptr), {buffer_shape[0] / tp_size, buffer_shape[1]}, + buffer_dtype)); + ubuf_byte_ptr += buffer_chunk_bytes; } - bool is_atomic_gemm() { return _atomic_gemm; } - bool is_p2p_overlap() { return false; } -}; // UbufCommOverlap - -struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { - int _tp_id; - int _tp_size; - int _ub_reg, _ub_reg2; - int _next_rank, _prev_rank, _rank, _rank_round_tp; - int _aggregate2; - int _math_sms; - int _self_chunk_id; - void *_ubuf_ptr; - torch::Tensor _ubuf; - torch::Tensor counter; - torch::Tensor _ubuf_scale_inv; - bool _ubuf_scale_inv_initialized; - std::vector _ubufs; - at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true); - at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true); - std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv; - int _use_ce; - int _num_comm_sm; - int _cga_size; - bool _atomic_gemm; - - UbufP2PCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size, - bool set_sm_margin, bool aggregate2, int num_max_streams, - bool is_reduce_scatter, bool atomic_gemm, bool use_ce, - UbufBootstrapCallbacks &callbacks) { - // Initialize userbuf communicator - if (!comm_created) { - if (myrank == 0) { - printf("!!! [UB] Create Userbuffers Communicator\n"); + _rank_round_tp = (_rank / _tp_size) * _tp_size; + _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp; + _prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp; + + _self_chunk_id = _tp_id; + if (_atomic_gemm && !_is_reduce_scatter) { + _use_multiatomic_ag = getenv("NVTE_AG_P2P_MULTI_ATOMIC"); + if (_use_multiatomic_ag) { + _use_ce = 0; + _ub_comm->push = 1; + if (_rank == 0) { + printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); } -#ifdef NVTE_UB_WITH_MPI - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); -#else - create_communicator_grouped2( - &_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5), - std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1); -#endif - comm_created = true; - } - _use_ce = use_ce; - _num_comm_sm = num_comm_sm; - _cga_size = comm_cga_size; - - // Create workspace tensor with userbuffer - int ubuf_bytes = sample.numel() * sample.element_size(); - int ubuf_chunk_bytes = ubuf_bytes / tp_size; - int num_ubuf_chunks = tp_size; - if (is_reduce_scatter) { - // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk - // outputs for reduction at the end of the pipelining. - ubuf_bytes = static_cast(ubuf_bytes / tp_size * (tp_size * 2 - 1)); - num_ubuf_chunks = static_cast(tp_size * 2 - 1); - } - - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - _ubuf = torch::from_blob( - _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options()); - if (_ub_comm->myrank == 0) { - printf("!!! 
[UBP2P] Register UBuf %d\n", _ub_reg); } + _self_chunk_id = 0; + NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); + } - // Create tensor chunks for easy management - char *ubuf_byte_ptr = reinterpret_cast(_ubuf.data_ptr()); - for (int i = 0; i < num_ubuf_chunks; i++) { - auto ubuf_chunk = torch::from_blob(ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, - sample.options()); - _ubufs.push_back(std::move(ubuf_chunk)); - ubuf_byte_ptr += ubuf_chunk_bytes; - } + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_send, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); +} - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - for (int i = 0; i < std::min(num_max_streams, tp_size); i++) { - cudaStream_t stream; - cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); - _stream_compute.push_back( - at::cuda::getStreamFromExternal(stream, stream_main.device_index())); - } +CommOverlapP2PBase::~CommOverlapP2PBase() { + cudaEventDestroy(_stop_recv); + cudaEventDestroy(_stop_send); + cudaStreamDestroy(_stream_recv); + cudaStreamDestroy(_stream_send); +} - // Set the number of SMs for GEMM with margin - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; - _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); - - _tp_size = tp_size; - _aggregate2 = aggregate2; - - _rank = _ub_comm->myrank; - _tp_id = (_rank % _tp_size); - _rank_round_tp = (_rank / _tp_size) * _tp_size; - _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp; - _prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp; - _ubuf_scale_inv_initialized = false; - - _atomic_gemm = atomic_gemm; - _self_chunk_id = _tp_id; - if (_atomic_gemm) { - auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({_tp_size * 2}, counter_options); - counter.index_put_({Slice(None, _tp_size)}, 1); - - if (!is_reduce_scatter) { - const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); - if (_rank == 0 && env_p != nullptr) { - if (env_p[0] == '1') { - _use_ce = 0; - _ub_comm->push = 1; - printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); - } - } - _self_chunk_id = 0; - counter.index_put_({_self_chunk_id}, 0); +/* +** Split AllGather + AtomicGEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is +*needed to have AG outputs +** in each rank to be in the contiguous memory space after all ring exchange +*phases. +*/ +void CommOverlapP2PBase::atomic_gemm_overlap_ag( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &B_copy, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? 
A.size(0) : A.size(1); + const size_t n = _ubuf.size(0); + const size_t n_chunk = n / _tp_size; + assert(pre_gelu_out.numel() == 0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Create an GEMM output buffer with N+1 chunks in a contiguous memory + void *D_buffer_ptr; + int D_chunk_bytes = n_chunk * m * D.element_size(); + NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); + auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); + + // Reset atomic counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _tp_size, true, stream_main); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + + auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv()); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + auto workspace_chunk = TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, + workspace.dtype()); + + for (int i = 0; i < _tp_size - 1; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = i; + int recv_chunk_id = i + 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + if (_use_multiatomic_ag) { + if (i == 0) { + _ub_comm->use_ce = 0; + userbuffers_sendrecv_multiatomic( + _ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, _ub_comm, _next_rank, _prev_rank, + _tp_size, counter_ptr, true, _stream_recv); } + } else { + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank, + _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank, + _stream_recv); + producer(counter_ptr, recv_chunk_id, _stream_recv); + } + if (i == 0) { + nvte_cublas_atomic_gemm( + A.data(), input_b.data(), D_buffer.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, + _math_sms, 0, _tp_size, false, _counter.data(), stream_main); } - - // CUDA event creation - cudaEventCreateWithFlags(&_start_compute, 0); - cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_send, 0); - cudaEventCreateWithFlags(&_stop_recv, 0); } - ~UbufP2PCommOverlap() { - cudaEventDestroy(_stop_recv); - cudaEventDestroy(_stop_send); - cudaEventDestroy(_start_comm); - cudaEventDestroy(_stop_compute); - cudaEventDestroy(_start_compute); - - for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); - - if (comm_created) { -#ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); -#else - destroy_communicator(_ub_comm); -#endif - comm_created = false; - } + // Store the input activation for backprop + if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); + assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); + NVTE_CHECK_CUDA( + cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), + 
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
+                        cudaMemcpyDeviceToDevice, _stream_send));
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
  }
 
+  // Copy the first GEMM output chunk to the end chunk position of D_buffer
+  char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr,
+                                  D_chunk_bytes, cudaMemcpyDeviceToDevice,
+                                  stream_main));
+
+  // Return the last N rows of D_buffer
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(),
+                                  cudaMemcpyDeviceToDevice, stream_main));
+
+  // Clean up buffer allocation
+  NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main));
+
+  _ub_comm->sms = ori_sms;
+}  // CommOverlapP2PBase::atomic_gemm_overlap_ag
+
+/*
+** Split AllGather + GEMM using P2P communication
+** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is
+*needed to have AG outputs
+** in each rank to be in the contiguous memory space after all ring exchange
+*phases.
+*/
+void CommOverlapP2PBase::split_overlap_ag(
+    const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
+    const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out,
+    const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator,
+    const TensorWrapper &B_copy, cudaStream_t stream_main) {
+  int ori_sms = _ub_comm->sms;
+  _ub_comm->use_ce = _use_ce;
+  _ub_comm->sms = _num_comm_sm;
+  _ub_comm->cga_size = _cga_size;
+  // Get GEMM dimensions between TN and NN input layouts
+  const size_t m = (transa) ? A.size(0) : A.size(1);
+  const size_t k = (transa) ? A.size(1) : A.size(0);
+  const size_t n_chunk = _ubufs[0].size(0);
+
+  // Get communication and GEMM output chunk sizes
+  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
+  const bool do_gelu = pre_gelu_out.numel() > 0;
+  const int output_chunk_bytes = (n_chunk * m) * D.element_size();
+  const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0;
+
+  // Get output and workspace data pointers
+  char *output_ptr = reinterpret_cast<char *>(D.dptr());
+  char *pre_gelu_out_ptr = reinterpret_cast<char *>(pre_gelu_out.dptr());
+  char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
+  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
+
+  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
+  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
+  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
+  for (size_t i = 0; i < _stream_compute.size(); i++) {
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
+  }
+  if (_aggregate) {
+    const int num_steps = _tp_size / 2;
+    char *input_b_ptr = reinterpret_cast<char *>(_ubuf.dptr());
+
+    // Initial 1X input chunk exchange between neighboring peers
+    int send_chunk_id = _tp_id;
+    int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1;
+    int send_offset = comm_bytes * send_chunk_id;
+    int recv_offset = comm_bytes * recv_chunk_id;
+    int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
+    userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
+                     _stream_send);
+    userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
+                     _stream_recv);
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0));
+
+    int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; + const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; + const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; + + // Ring exchange of 2X inputs chunks + for (int i = 0; i < num_steps; i++) { + send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; + recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; + send_offset = comm_bytes * send_chunk_id; + recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + char *input_b_chunk_ptr = input_b_ptr + send_offset; + auto input_b_chunk = TensorWrapper( + reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), nullptr, + nullptr, B.scale_inv()); + + char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); + auto output_chunk = TensorWrapper( + reinterpret_cast(output_chunk_ptr), {n_chunk * 2, m}, D.dtype(), D.amax(), + D.scale(), nullptr); + + char *aux_chunk_ptr = (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) + : nullptr; + auto aux_chunk_shape = (do_gelu) ? std::vector{n_chunk * 2, m} + : std::vector{0}; + auto aux_chunk = TensorWrapper( + reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, pre_gelu_out.dtype()); + + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = TensorWrapper( + reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, + workspace.dtype()); + + nvte_cublas_gemm( + A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, + _math_sms, _stream_compute[i % _stream_compute.size()]); + + if (i < num_steps - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, + next_rank, _stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, + prev_rank, _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent( + _stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, _stream_send)); + } + } + } else { + for (int i = 0; i < _tp_size; i++) { // Set the userbuffer id. Buffer under send is the input for the current // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to // have the AG output in all ranks to be contiguous after the ring // exchanges - int send_chunk_id = i; - int recv_chunk_id = i + 1; + int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; + int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; int send_offset = comm_bytes * send_chunk_id; int recv_offset = comm_bytes * recv_chunk_id; - const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); - if (env_p != nullptr && env_p[0] == '1') { - if (i == 0) { - _ub_comm->use_ce = 0; - userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, - _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, - true, (cudaStream_t)_stream_recv); - } - } else { - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _next_rank, (cudaStream_t)_stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _prev_rank, (cudaStream_t)_stream_recv); - producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv); - } - if (i == 0) { - te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb, - D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms, 0, _tp_size, false, counter); - } - } - - // Store the input activation for backprop - if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); - assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); - NVTE_CHECK_CUDA( - cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(), - _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - } - - // Reset atomic counters - consumer_batch(counter_ptr, 1, _tp_size, (cudaStream_t)stream_main); - - // Copy the first GEMM output chunk to the end chunk position of D_buffer - char *src_ptr = reinterpret_cast(D_buffer.data_ptr()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, - n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); - // Return the last N rows of D_buffer - _ub_comm->sms = ori_sms; - torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n); - return D_return; - } // atomic_gemm_overlap_ag - - /* - ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. - */ - torch::Tensor split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions between TN and NN input layouts - const int m = (transa) ? 
A.size(0) : A.size(1); - const int k = (transa) ? A.size(1) : A.size(0); - const int n_chunk = _ubufs[0].size(0); - - // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const bool do_gelu = pre_gelu_out.numel() > 0; - const int output_chunk_bytes = (n_chunk * m) * D.element_size(); - const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; - - // Get output and workspace data pointers - char *output_ptr = reinterpret_cast(D.data_ptr()); - char *pre_gelu_out_ptr = reinterpret_cast(pre_gelu_out.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - } - if (_aggregate2) { - const int num_steps = _tp_size / 2; - char *input_b_ptr = reinterpret_cast(_ubuf.data_ptr()); - - // Initial 1X input chunk exchange between neighboring peers - int send_chunk_id = _tp_id; - int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); - - int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; - const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; - const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; - - // Ring exchange of 2X inputs chunks - for (int i = 0; i < num_steps; i++) { - send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; - recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; - send_offset = comm_bytes * send_chunk_id; - recv_offset = comm_bytes * recv_chunk_id; - - // GEMM - torch::Tensor input_b_chunk = - torch::from_blob(input_b_ptr + send_offset, {n_chunk * 2, k}, _ubuf.options()); - torch::Tensor output_chunk = torch::from_blob( - output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options()); - if (do_gelu) { - pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), - {n_chunk * 2, m}, pre_gelu_out.options()); - } - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - if (i < num_steps - 1) { - // P2P communication - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, - next_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, - prev_rank, (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - } - } - } else { - for (int i = 0; i < _tp_size; i++) { - // Set the userbuffer id. Buffer under send is the input for the current - // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to - // have the AG output in all ranks to be contiguous after the ring - // exchanges - int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; - int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - - // GEMM - torch::Tensor output_chunk = torch::from_blob( - output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options()); - if (do_gelu) { - pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), - {n_chunk, m}, pre_gelu_out.options()); - } - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, _ubufs[send_chunk_id], B_scale_inverse, B_type, - transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - if (i < _tp_size - 1) { - // P2P communication - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, - _next_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _prev_rank, (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - } + // GEMM + auto input_b_chunk = TensorWrapper( + _ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(), nullptr, nullptr, B.scale_inv()); + + char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); + auto output_chunk = TensorWrapper( + reinterpret_cast(output_chunk_ptr), {n_chunk, m}, D.dtype(), D.amax(), D.scale(), + nullptr); + + char *aux_chunk_ptr = (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) + : nullptr; + auto aux_chunk_shape = (do_gelu) ? 
std::vector{n_chunk, m} + : std::vector{0}; + auto aux_chunk = TensorWrapper( + reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, pre_gelu_out.dtype()); + + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = TensorWrapper( + reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, + workspace.dtype()); + + nvte_cublas_gemm( + A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, + _math_sms, _stream_compute[i % _stream_compute.size()]); + + if (i < _tp_size - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, + _next_rank, _stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + _prev_rank, _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent( + _stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, _stream_send)); } } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - at::cuda::setCurrentCUDAStream(stream_main); - _ub_comm->sms = ori_sms; - - return D; - } // split_overlap_ag - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - - // Get communication and GEMM input chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - - // Get input and workspace data pointers - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int *counter_ptr = reinterpret_cast(counter.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - // Catch up the main stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - 
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - - // Atomic GEMM - // Process GEMM chunks in the order that AG+GEMM places the output chunks. - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, _ubuf, - D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace_chunk, - workspace_size_chunk, accumulate, use_split_accumulator, _math_sms, 0, _tp_size, - true, counter); - - // P2P communication chunk - for (int i = 1; i < _tp_size; i++) { - int send_chunk_id = i - 1; - int recv_chunk_id = send_chunk_id + _tp_size; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - - consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, - (cudaStream_t)_stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, - (cudaStream_t)_stream_recv); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - - // Reduce GEMM output chunks - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, - _ubufs[0].numel(), (cudaStream_t)stream_main);); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); - } - _ub_comm->sms = ori_sms; } - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - int k = A.size(1); - int n = B.size(0); - - // Get communication and GEMM input chunk sizes - int n_chunk = n / _tp_size; - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const int input_b_chunk_bytes = n_chunk * k * B.element_size(); - - // Get input and workspace data pointers - char *input_b_ptr = reinterpret_cast(B.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) 
A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - // Catch up the main stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - } + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); +} // CommOverlapP2PBase::split_overlap_ag + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2PBase::atomic_gemm_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &rs_output, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get communication and GEMM input chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Reset counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _tp_size, false, stream_main); + + // Catch up the main stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + + // Atomic GEMM + // Process GEMM chunks in the order that AG+GEMM places the output chunks. 
+ auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + auto workspace_chunk = TensorWrapper( + workspace.data(), std::vector{workspace_size_chunk}, workspace.dtype()); + nvte_cublas_atomic_gemm( + A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, + workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, 0, _tp_size, true, + _counter.data(), stream_main); + + // P2P communication chunk + for (int i = 1; i < _tp_size; i++) { + int send_chunk_id = i - 1; + int recv_chunk_id = send_chunk_id + _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + + consumer(counter_ptr, send_chunk_id, _stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + _stream_recv); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + _ub_comm->sms = ori_sms; +} - // GEMM and send/recv chunks - for (int i = 0; i < _tp_size; i++) { - // GEMM chunk - int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; - char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); - torch::Tensor input_b_chunk = torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options()); - // Store the last GEMM chunk output to the recieve buffer. 
- torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, - _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); - - if (i > 0) { - // P2P communication chunk - int send_offset = comm_bytes * (i - 1); - int recv_offset = comm_bytes * (i - 1 + _tp_size); - int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - NVTE_CHECK_CUDA(cudaEventRecord( - _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0)); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - send_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - recv_rank, (cudaStream_t)_stream_recv); - } - } - at::cuda::setCurrentCUDAStream(stream_main); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - - // Reduce GEMM output chunks - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, - _ubufs[0].numel(), (cudaStream_t)stream_main);); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); - } - _ub_comm->sms = ori_sms; +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2PBase::split_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &rs_output, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + size_t k = A.size(1); + size_t n = B.size(0); + + // Get communication and GEMM input chunk sizes + size_t n_chunk = n / _tp_size; + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int input_b_chunk_bytes = n_chunk * k * B.element_size(); + + // Get input and workspace data pointers + 
char *input_b_ptr = reinterpret_cast(B.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Catch up the main stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); } - /* - ** Copy input to _ubufs[0] - */ - void copy_input_to_ubuf(torch::Tensor input, bool chunk) { - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - if (chunk) { - // Copy input to the target ubuf chunk by rank offset - if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); - } else { - if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); + // GEMM and send/recv chunks + for (int i = 0; i < _tp_size; i++) { + // GEMM chunk + int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; + char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); + + auto input_b_chunk = TensorWrapper( + reinterpret_cast(input_b_chunk_ptr), {n_chunk, k}, B.dtype(), nullptr, nullptr, + B.scale_inv()); + + auto output_chunk = TensorWrapper( + _ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr); + + char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = TensorWrapper( + reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, + workspace.dtype()); + + nvte_cublas_gemm( + A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + if (i > 0) { + // P2P communication chunk + int send_offset = comm_bytes * (i - 1); + int recv_offset = comm_bytes * (i - 1 + _tp_size); + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + NVTE_CHECK_CUDA(cudaEventRecord( + _start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0)); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + send_rank, _stream_send); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + recv_rank, _stream_recv); } } - torch::Tensor get_ubuf_output(int comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); - if (_comm_type == COMM_TYPE::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); - int 
output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); - } - void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { - _ubuf_scale_inv = scale_inv; - _ubuf_scale_inv_initialized = true; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); } - bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } - bool is_atomic_gemm() { return _atomic_gemm; } - bool is_p2p_overlap() { return true; } -}; // UbufP2PCommOverlap - -} // namespace ubuf + _ub_comm->sms = ori_sms; +} -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ +} // namespace transformer_engine diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index e2628f6a31..7f6ece4eac 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -108,9 +108,8 @@ int pipe_rank(communicator *comm, int step) { int create_communicator_grouped2( communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes, int tensorgpus, - int tensornodes) { + int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, int pipegpus, + int pipenodes, int tensorgpus, int tensornodes) { *comm = new communicator(); (*comm)->comm_world = EXT_COMM_WORLD; @@ -348,16 +347,15 @@ int create_communicator_grouped2( int create_communicator_grouped( communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes) { + int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, int pipegpus, + int pipenodes) { return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1); } -int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, - std::function ext_allgather, - std::function ext_barrier) { +int create_communicator( + communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, ExtAllgatherOp ext_allgather, 
ExtBarrierOp ext_barrier) {
   return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
                                       ext_allgather, ext_barrier, 1, 1, 1, 1);
 }
diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
index 0cd2a0253b..37b9053696 100644
--- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
+++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -2496,6 +2496,19 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i
   }
 }
 
+// reset counters kernel
+static __global__ void reset_counters_kernel(void *atomic_ptr, int num_chunks, bool allgather) {
+  if (blockIdx.x == 0 && threadIdx.x == 0) {
+    #pragma unroll
+    for (int i = 0; i < num_chunks; i++) {
+      ((unsigned int *)atomic_ptr)[i] = 1;
+      ((unsigned int *)atomic_ptr)[i + num_chunks] = 0;
+    }
+    if (allgather)
+      ((unsigned int *)atomic_ptr)[0] = 0;
+  }
+}
+
 void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
   dim3 block(1);
   dim3 grid(1);
@@ -2514,6 +2527,12 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr
   consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
 }
 
+void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) {
+  dim3 block(1);
+  dim3 grid(1);
+  reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
+}
+
 template <typename fp8type>
 __global__ void __launch_bounds__(MAX_THREADS / 4)
     reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,
@@ -2546,3 +2565,24 @@ template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output,
 template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale,
                                                     int num_inputs, int input_size,
                                                     cudaStream_t stream);
+
+__global__ void __launch_bounds__(MAX_THREADS / 4)
+    reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size) {
+  const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
+  half *inputs_half = reinterpret_cast<half *>(inputs);
+  float accum_buf = static_cast<float>(inputs_half[tid]);
+#pragma unroll
+  for (int i = 1; i < num_inputs; i++) {
+    accum_buf += static_cast<float>(inputs_half[tid + input_size * i]);
+  }
+  half *output_half = reinterpret_cast<half *>(output);
+  output_half[tid] = (half)accum_buf;
+}
+
+void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) {
+  size_t num_threads = MAX_THREADS / 4;
+  size_t num_blocks = (input_size + num_threads - 1) / num_threads;
+  dim3 block(num_threads);
+  dim3 grid(num_blocks);
+  reduce_bf16_cuda<<<grid, block, 0, stream>>>(inputs, output, num_inputs, input_size);
+}
diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
index 371932f446..eb2d812824 100644
--- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
+++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
@@ -24,6 +24,9 @@ typedef MPI_Comm ExtComm;
 typedef char *ExtComm;
 #endif
 
+#define ExtAllgatherOp std::function<void(void *, size_t, void *, size_t, ExtComm)>
+#define ExtBarrierOp std::function<void(ExtComm)>
+
 #define NVTE_MAX_REGIONS 16
 #define NVTE_MAX_SMS 32
 #define NVTE_MAX_OPS 32
@@ -142,8 +145,8 @@ struct communicator {
   volatile int tail;
 
   // Abstract communication callbacks to support external bootstrapping (e.g. 
DL frameworks) - std::function _allgather; - std::function _barrier; + ExtAllgatherOp _allgather; + ExtBarrierOp _barrier; ExtComm comm_world, comm_inter, // reduction group communicator (subset of the nodes) along GPU rail @@ -161,23 +164,22 @@ typedef struct communicator communicator; void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream); void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream); void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream); +void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream); /* creates communicator, allocates all internal buffers if necessary */ int create_communicator_grouped2( communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes, int tensorgpus, - int tensornodes); + int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, int pipegpus, + int pipenodes, int tensorgpus, int tensornodes); int create_communicator_grouped( communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes); + int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, int pipegpus, + int pipenodes); -int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, - std::function ext_allgather, - std::function ext_barrier); +int create_communicator( + communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier); int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes, int tensorgpus, int tensornodes); @@ -314,4 +316,6 @@ template void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); +void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream); + #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h new file mode 100644 index 0000000000..fdf9158f57 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -0,0 +1,204 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_ +#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_ + +#include + +#include +#include + +#include + +#include "common/comm_gemm_overlap/userbuffers/userbuffers.h" + +#define NVTE_COMM_OVERLAP_MAX_STREAMS 3 + +namespace transformer_engine { + +bool device_supports_multicast(); + +bool ubuf_built_with_mpi(); + +enum class CommOverlapType { RS = 0, AG = 1 }; + +enum class CommOverlapAlgo { + BULK_OVERLAP_AG = 0, + BULK_OVERLAP_RS = 1, + SPLIT_PIPELINED_AG_P2P = 2, + SPLIT_PIPELINED_RS = 3, + SPLIT_PIPELINED_RS_P2P = 4, + ATOMIC_GEMM_RS = 5, + ATOMIC_GEMM_AG_P2P = 6, + ATOMIC_GEMM_RS_P2P = 7 +}; + +class CommOverlapCore { + public: + static inline communicator *_ub_comm{nullptr}; + static inline bool _comm_created{false}; + + int _rank; + int _tp_id; + int _tp_size; + int _num_splits; + int _math_sms; + int _num_comm_sm; + int _cga_size; + int _use_ce; + bool _atomic_gemm{false}; + bool _is_p2p{false}; + + int _ub_reg; + TensorWrapper _ubuf; + TensorWrapper _counter; + float *_ubuf_scale_inv; + bool _ubuf_scale_inv_initialized{false}; + + std::vector _stream_compute; + cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm; + + CommOverlapCore( + int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, + int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm); + + virtual ~CommOverlapCore(); + + void set_ubuf_scale_inv(float *scale_inv) { + _ubuf_scale_inv = scale_inv; + _ubuf_scale_inv_initialized = true; + } + + bool is_atomic_gemm() { return _atomic_gemm; } + + bool is_p2p_overlap() { return _is_p2p; } + + bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } +}; // CommOverlapCore + +class CommOverlapBase : public CommOverlapCore { + public: + int _rs_kernel_type; + cudaStream_t _stream_comm; + cudaEvent_t _start_d2dcopy; + + CommOverlapBase( + const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, + int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + + ~CommOverlapBase(); + + /* + ** Bulk GEMM + COMM + ** This function assumes the communication input is pre-copied to _ubuf + */ + void bulk_overlap( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, const TensorWrapper &rs_output, cudaStream_t stream_main); + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void atomic_gemm_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, const TensorWrapper &rs_output, cudaStream_t stream_main); + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void split_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const 
TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, const TensorWrapper &rs_output, cudaStream_t stream_main); +}; // CommOverlapBase + +class CommOverlapP2PBase : public CommOverlapCore { + public: + bool _is_reduce_scatter{false}; + bool _use_multiatomic_ag{false}; + + int _next_rank; + int _prev_rank; + int _rank_round_tp; + int _aggregate; + int _num_ubuf_chunks; + int _self_chunk_id; + + std::vector _ubufs; + + cudaStream_t _stream_send; + cudaStream_t _stream_recv; + cudaEvent_t _stop_send, _stop_recv; + + CommOverlapP2PBase( + const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, + int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, CommOverlapType comm_type, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 1, + int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, + bool atomic_gemm = false, bool aggregate = false); + + ~CommOverlapP2PBase(); + + /* + ** Split AllGather + AtomicGEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is + *needed to have AG outputs + ** in each rank to be in the contiguous memory space after all ring exchange + *phases. + */ + void atomic_gemm_overlap_ag( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &B_copy, cudaStream_t stream_main); + + /* + ** Split AllGather + GEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is + *needed to have AG outputs + ** in each rank to be in the contiguous memory space after all ring exchange + *phases. 
+ */ + void split_overlap_ag( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &B_copy, cudaStream_t stream_main); + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void atomic_gemm_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &rs_output, cudaStream_t stream_main); + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void split_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &rs_output, cudaStream_t stream_main); +}; // CommOverlapP2PBase + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 191fc40ead..d302518235 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -78,13 +78,13 @@ NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType */ void nvte_destroy_tensor(NVTETensor tensor); -/*! \brief Get a tensor's data type. +/*! \brief Get a raw pointer to the tensor's data. * * \param[in] tensor Tensor. * - * \return A data type of the input tensor. + * \return A raw pointer to tensor's data. */ -NVTEDType nvte_tensor_type(const NVTETensor tensor); +void *nvte_tensor_data(const NVTETensor tensor); /*! \brief Get a tensor's data shape. * @@ -94,13 +94,46 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor); */ NVTEShape nvte_tensor_shape(const NVTETensor tensor); -/*! \brief Get a raw pointer to the tensor's data. +/*! \brief Get a tensor's number of dimensions. * * \param[in] tensor Tensor. * - * \return A raw pointer to tensor's data. + * \return Number of tensor dimensions. */ -void *nvte_tensor_data(const NVTETensor tensor); +size_t nvte_tensor_ndims(const NVTETensor tensor); + +/*! \brief Get the size of a specific tensor dimension. + * + * \param[in] tensor Tensor. + * \param[in] size_t Dimension index. + * + * \return Size of the tensor at the specified dimension. + */ +size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim); + +/*! \brief Get a tensor's total number of elements. + * + * \param[in] tensor Tensor. + * + * \return Number of elements in the tensor. + */ +size_t nvte_tensor_numel(const NVTETensor tensor); + +/*! \brief Get the byte size for the tensor's data type. + * + * \param[in] tensor Tensor. + * + * \return Byte size of the tensor's data type. + */ +size_t nvte_tensor_element_size(const NVTETensor tensor); + +/*! \brief Get a tensor's data type. + * + * \param[in] tensor Tensor. + * + * \return A data type of the input tensor. + */ +NVTEDType nvte_tensor_type(const NVTETensor tensor); /*! \brief Get a pointer to the tensor's amax data. 
* @@ -265,6 +298,56 @@ class TensorWrapper { return nvte_tensor_shape(tensor_); } + /*! \brief Get the size of this TensorWrapper in the given dimension. + * + * \param[in] size_t Dimension index. + * + * \return Size of this TensorWrapper in given dimension. + */ + size_t size(const size_t dim) const { + if (tensor_ == nullptr) return 0; + return nvte_tensor_size(tensor_, dim); + } + + /*! \brief Get the number of dimensions for this TensorWrapper. + * + * \return Number of dimensions for this TensorWrapper. + */ + size_t ndim() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_tensor_ndims(tensor_); + } + + /*! \brief Get the number of allocated elements in the tensor. This will return 0 for tensors + * with nullptr data even if the TensorWrapper has a non-zero shape. + * + * + * \return Number of elements in the tensor. + */ + size_t numel() const noexcept { + if (tensor_ == nullptr || this->dptr() == nullptr) return 0; + return nvte_tensor_numel(tensor_); + } + + /*! \brief Get the tensor's element size in bytes. + * + * \return Element size in bytes. + */ + size_t element_size() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_tensor_element_size(tensor_); + } + + /*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr + * data even if the TensorWrapper has a non-zero shape and valid dtype. + * + * \return Total tensor size in bytes. + */ + size_t bytes() const noexcept { + if (tensor_ == nullptr || this->dptr() == nullptr) return 0; + return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_); + } + /*! \brief Get the data type of this TensorWrapper. * * \return Data type of this TensorWrapper. @@ -317,6 +400,6 @@ class TensorWrapper { } // namespace transformer_engine -#endif +#endif // __cplusplus #endif // TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_ diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 5cfab2f8cf..1a3b49f9fa 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -93,6 +93,31 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { return ret; } +size_t nvte_tensor_ndim(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.data.shape.size(); +} + +size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) { + const auto &t = *reinterpret_cast(tensor); + NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim); + return t.data.shape[dim]; +} + +size_t nvte_tensor_numel(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + size_t numel = 1; + for (auto size : t.data.shape) { + numel *= size; + } + return numel; +} + +size_t nvte_tensor_element_size(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return transformer_engine::typeToSize(t.data.dtype); +} + void *nvte_tensor_data(const NVTETensor tensor) { const auto &t = *reinterpret_cast(tensor); return t.data.dptr; From f4886207c6710c7b0db8bc8eec116dcda8e18928 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Mon, 26 Aug 2024 23:07:26 +0000 Subject: [PATCH 04/34] added TE/PyTorch wrappers for refactored comm+GEMM overlap code in TE/common Signed-off-by: Alp Dener --- build_tools/pytorch.py | 19 - setup.py | 9 +- transformer_engine/common/CMakeLists.txt | 14 +- .../common/util/pybind_helper.h | 78 ++++ transformer_engine/pytorch/csrc/common.h | 5 +- transformer_engine/pytorch/csrc/extensions.h | 156 
+++++++ .../csrc/extensions/comm_gemm_overlap.cpp | 420 ++++++++++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 185 +++----- 8 files changed, 746 insertions(+), 140 deletions(-) create mode 100644 transformer_engine/common/util/pybind_helper.h create mode 100644 transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 9152229d2f..ba1827c6bb 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -29,9 +29,6 @@ def setup_pytorch_extension( sources = [ csrc_source_files / "common.cu", csrc_source_files / "ts_fp8_op.cpp", - csrc_source_files / "userbuffers" / "ipcsocket.cc", - csrc_source_files / "userbuffers" / "userbuffers.cu", - csrc_source_files / "userbuffers" / "userbuffers-host.cpp", ] + all_files_in_dir(extensions_dir) # Header files @@ -85,20 +82,6 @@ def setup_pytorch_extension( continue # Already handled nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"]) - # Libraries - library_dirs = [] - libraries = [] - if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" - mpi_home = Path(os.getenv("MPI_HOME")) - include_dirs.append(mpi_home / "include") - cxx_flags.append("-DNVTE_UB_WITH_MPI") - nvcc_flags.append("-DNVTE_UB_WITH_MPI") - library_dirs.append(mpi_home / "lib") - libraries.append("mpi") - # Construct PyTorch CUDA extension sources = [str(path) for path in sources] include_dirs = [str(path) for path in include_dirs] @@ -112,6 +95,4 @@ def setup_pytorch_extension( "cxx": cxx_flags, "nvcc": nvcc_flags, }, - libraries=[str(lib) for lib in libraries], - library_dirs=[str(lib_dir) for lib_dir in library_dirs], ) diff --git a/setup.py b/setup.py index 512defa619..fbf7a37723 100644 --- a/setup.py +++ b/setup.py @@ -57,13 +57,20 @@ def run(self): def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" + cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())] + if os.getenv("NVTE_UB_WITH_MPI"): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" + cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") + # Project directory root root_path = Path(__file__).resolve().parent return CMakeExtension( name="transformer_engine", cmake_path=root_path / Path("transformer_engine/common"), - cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())], + cmake_flags=cmake_flags, ) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cabb2e2aea..4e01dbd710 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -80,7 +80,11 @@ list(APPEND transformer_engine_SOURCES fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_rope/fused_rope.cu - recipe/delayed_scaling.cu) + recipe/delayed_scaling.cu + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/userbuffers/userbuffers.cu + comm_gemm_overlap/comm_gemm_overlap.cpp) add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") @@ -93,6 +97,14 @@ target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE 
"${CUDNN_FRONTEND_INCLUDE_DIR}") +# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI +option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) +if (NVTE_UB_WITH_MPI) + find_package(MPI REQUIRED) + target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) + target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) +endif() + # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h new file mode 100644 index 0000000000..7f903089e6 --- /dev/null +++ b/transformer_engine/common/util/pybind_helper.h @@ -0,0 +1,78 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ + +#include + +#include +#include +#include + +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType", pybind11::module_local()) \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type") \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type") \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Layout", py::module_local()) \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + pybind11::enum_(m, 
"NVTE_Fused_Attn_Backend", py::module_local()) \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "CommOverlapType", py::module_local()) \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo", py::module_local()) \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + m.def("device_supports_multicast", &transformer_engine::device_supports_multicast, \ + py::call_guard()); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); + +#endif diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 04a1193a71..243306425f 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -21,9 +22,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -36,7 +39,7 @@ #include #include -#include +#include #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index c30e583178..8306add5ee 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -504,4 +504,160 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list); +/*************************************************************************************************** + * Comm+GEMM Overlap Wrappers + **************************************************************************************************/ + +class CommOverlapHelper : torch::CustomClassHolder { + private: + bool initialized{false}; + bool backend_is_nccl{false}; + std::map pgs; + + public: + CommOverlapHelper(); + + CommOverlapHelper(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group); + + ~CommOverlapHelper(); + + void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, + char *group); + + void ub_barrier(char *group); +}; + +class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { + public: + CommOverlap( + const std::vector &buffer_shape, at::ScalarType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + CommOverlapHelper *callbacks, int num_splits = 3, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_comm_sm = 
16, bool set_sm_margin = true, bool atomic_gemm = false); + + void set_ubuf_scale_inv(torch::Tensor scale_inv) { + assert(scale_inv.numel()); + assert(scale_inv.scalar_type() == torch::kFloat32); + transformer_engine::CommOverlapBase::set_ubuf_scale_inv( + reinterpret_cast(scale_inv.data_ptr())); + } + + void copy_input_to_ubuf(torch::Tensor input, int comm_type); + + torch::Tensor get_ubuf_output(int comm_type); + + /* + ** Bulk GEMM + COMM + ** This function assumes the communication input is pre-copied to _ubuf + */ + std::vector bulk_overlap( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + transformer_engine::CommOverlapType comm_type, at::Tensor rs_output); + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void atomic_gemm_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output); + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void split_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output); +}; // CommOverlap + +class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { + public: + CommOverlapP2P( + const std::vector &buffer_shape, at::ScalarType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + CommOverlapHelper *callbacks, transformer_engine::CommOverlapType comm_type, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int num_comm_sm = 3, + bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, + bool aggregate = false); + + void set_ubuf_scale_inv(torch::Tensor scale_inv) { + assert(scale_inv.numel()); + assert(scale_inv.scalar_type() == torch::kFloat32); + transformer_engine::CommOverlapP2PBase::set_ubuf_scale_inv( + reinterpret_cast(scale_inv.data_ptr())); + } + + void copy_input_to_ubuf(torch::Tensor input, bool chunk); + + torch::Tensor get_ubuf_output(int comm_type); + + /* + ** Split AllGather + AtomicGEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. 
This is
+  *needed to have AG outputs
+  ** in each rank to be in the contiguous memory space after all ring exchange
+  *phases.
+  */
+  void atomic_gemm_overlap_ag(
+      at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
+      transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
+      int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
+      at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
+      transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
+      size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy);
+
+  /*
+  ** Split AllGather + GEMM using P2P communication
+  ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is
+  *needed to have AG outputs
+  ** in each rank to be in the contiguous memory space after all ring exchange
+  *phases.
+  */
+  void split_overlap_ag(
+      at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
+      transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
+      int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
+      at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
+      transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
+      size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy);
+
+  /*
+  ** Split ReduceScatter + GEMM using P2P communication
+  */
+  void atomic_gemm_overlap_rs(
+      at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
+      transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
+      int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
+      at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
+      transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
+      size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output);
+
+  /*
+  ** Split ReduceScatter + GEMM using P2P communication
+  */
+  void split_overlap_rs(
+      at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
+      transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
+      int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
+      at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
+      transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
+      size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output);
+}; // CommOverlapP2P
+
 #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
new file mode 100644
index 0000000000..fa257bce62
--- /dev/null
+++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
@@ -0,0 +1,420 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
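
The new wrapper file below only exists to feed torch.distributed collectives into the framework-agnostic core; the core itself just needs an allgather and a barrier callback with the signatures of CommOverlapHelper::ub_allgather and ub_barrier. The following is a minimal sketch (not part of this patch) of driving the common core without PyTorch, e.g. with plain MPI. It is illustrative only: it assumes a single node, that ExtAllgatherOp/ExtBarrierOp accept captureless lambdas, that the buffer shape is a std::vector<size_t>, and that CommOverlapBase defaults its tuning parameters the way the P2P variant declared earlier does.

    // Hypothetical MPI-based bootstrap of the common core (not part of this patch).
    #include <memory>
    #include <vector>

    #include <mpi.h>
    #include <transformer_engine/comm_gemm_overlap.h>
    #include <transformer_engine/transformer_engine.h>

    std::unique_ptr<transformer_engine::CommOverlapBase> make_mpi_overlap(
        const std::vector<size_t> &buffer_shape, transformer_engine::DType dtype,
        int myrank, int numranks, int tp_size) {
      // Gather `localbytes` bytes from every rank into `globaldata`; the group tag
      // is ignored because this sketch uses a single communicator for everything.
      auto allgather = [](void *globaldata, size_t /*globalbytes*/, void *localdata,
                          size_t localbytes, char * /*group*/) {
        MPI_Allgather(localdata, static_cast<int>(localbytes), MPI_BYTE, globaldata,
                      static_cast<int>(localbytes), MPI_BYTE, MPI_COMM_WORLD);
      };
      auto barrier = [](char * /*group*/) { MPI_Barrier(MPI_COMM_WORLD); };

      // Single-node layout assumed: local rank == world rank, one node in total.
      return std::make_unique<transformer_engine::CommOverlapBase>(
          buffer_shape, dtype, myrank, numranks, /*mylocal=*/myrank,
          /*numlocal=*/numranks, /*mynode=*/0, /*numnodes=*/1, tp_size, allgather, barrier);
    }
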
+ ************************************************************************/ + + #include "../extensions.h" + +#define HALF_BYTES 2 +#define UB_MAX_SM 32 + +using namespace torch::indexing; +using namespace std::placeholders; + +namespace te = transformer_engine; + +#define MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inv, A_fp8_index, A_type, \ + B, B_scale_inv, B_fp8_index, B_type, \ + D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, \ + workspace) \ + A = A.contiguous(); \ + void *A_scale_inv_ptr = nullptr; \ + if (te::is_fp8_dtype(A_type)) { \ + assert(A_scale_inv.numel()); \ + A_scale_inv_ptr = A_scale_inv[A_fp8_index].data_ptr(); \ + } \ + auto A_ = makeTransformerEngineTensor( \ + A.data_ptr(), std::vector{(size_t)A.size(0), (size_t)A.size(1)}, A_type, nullptr, \ + nullptr, A_scale_inv_ptr); \ + B = B.contiguous(); \ + void *B_scale_inv_ptr = nullptr; \ + if (te::is_fp8_dtype(B_type)) { \ + assert(B_scale_inv.numel()); \ + B_scale_inv_ptr = B_scale_inv[B_fp8_index].data_ptr(); \ + } \ + auto B_ = makeTransformerEngineTensor( \ + B.data_ptr(), std::vector{(size_t)B.size(0), (size_t)B.size(1)}, B_type, nullptr, \ + nullptr, B_scale_inv_ptr); \ + void *D_amax_ptr = nullptr; \ + void *D_scale_ptr = nullptr; \ + if (te::is_fp8_dtype(D_type)) { \ + assert(D_amax.numel()); \ + D_amax_ptr = D_amax.data_ptr(); \ + assert(D_scale.numel()); \ + D_scale_ptr = D_scale.data_ptr(); \ + } \ + auto D_ = makeTransformerEngineTensor( \ + D.data_ptr(), std::vector{(size_t)D.size(0), (size_t)D.size(1)}, D_type, \ + D_amax_ptr, D_scale_ptr, nullptr); \ + auto bias_ = makeTransformerEngineTensor( \ + bias.data_ptr(), std::vector{(size_t)bias.size(0)}, bias_type); \ + const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \ + ? std::vector{static_cast(pre_gelu_out.size(0))} \ + : std::vector{static_cast(pre_gelu_out.size(0)), \ + static_cast(pre_gelu_out.size(1))}; \ + auto pre_gelu_out_ = makeTransformerEngineTensor( \ + pre_gelu_out.data_ptr(), gelu_shape, \ + GetTransformerEngineDType(pre_gelu_out.scalar_type())); \ + auto workspace_ = makeTransformerEngineTensor( \ + workspace.data_ptr(), std::vector{(size_t)workspace.size(0)}, te::DType::kByte); + +/*************************************************************************************************** + * CommOverlapHelper + **************************************************************************************************/ + +CommOverlapHelper::CommOverlapHelper() { +#ifndef NVTE_UB_WITH_MPI + NVTE_ERROR("Internal TE error: Dummy CommOverlapHelper init without NVTE_UB_WITH_MPI=1!"); +#endif +} // empty constructor for NVTE_UB_WITH_MPI=1 + +CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, + c10d::ProcessGroup *intra_node_group) { +pgs.insert({"world", world_group}); +c10d::ProcessGroup::BackendType backend = world_group->getBackendType(); +backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); + +NVTE_CHECK(intra_node_group->getBackendType() == backend, + "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", + "group!", world_group->getBackendName()); +pgs.insert({"intra", intra_node_group}); + +initialized = true; +} + +CommOverlapHelper::~CommOverlapHelper() { +for (auto &pg : pgs) pg.second = nullptr; +backend_is_nccl = false; +initialized = false; +} + +void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void *localdata, + size_t localbytes, char *group) { +NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + 
"with valid process groups!"); + +auto localtensor = + torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); +auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; +auto globaltensor = + torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); +auto globaltmp = (backend_is_nccl) ? globaltensor.cuda() : globaltensor; + +std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; +std::vector localchunk = {localtmp}; +auto work = pgs[group]->allgather(globalchunks, localchunk); +work->wait(); + +if (backend_is_nccl) { + globaltensor.copy_(globaltmp.cpu()); + globaltmp = torch::Tensor(); + localtmp = torch::Tensor(); +} +} + +void CommOverlapHelper::ub_barrier(char *group) { +NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); +auto work = pgs[group]->barrier(); +work->wait(); +} + +/*************************************************************************************************** + * CommOverlap + **************************************************************************************************/ + +CommOverlap::CommOverlap( + const std::vector &buffer_shape, at::ScalarType buffer_dtype, int myrank, + int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + CommOverlapHelper *callbacks, int num_splits, int num_max_streams, int comm_cga_size, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm) + : te::CommOverlapBase( + buffer_shape, GetTransformerEngineDType(buffer_dtype), + myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, callbacks, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, callbacks, _1), num_splits, num_max_streams, + comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) {} + +/* +** Bulk GEMM + COMM +** This function assumes the communication input is pre-copied to _ubuf +*/ +std::vector CommOverlap::bulk_overlap( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + te::CommOverlapType comm_type, at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, + B, B_scale_inverse, B_fp8_tensor, B_type, + D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, + workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapBase::bulk_overlap( + A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, + use_split_accumulator, comm_type, rs_out_, stream_main); + + // Get the current userbuf offset + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); + if (comm_type == te::CommOverlapType::RS) { + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + } + + // Generate output tensor from userbuf data pointer + int output_c_dim0 = (comm_type == te::CommOverlapType::AG) ? 
_ubuf.size(0) + : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + auto output_tensor = torch::from_blob( + ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, + torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); + + return {D, output_tensor}; +} // CommOverlap::bulk_overlap + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlap::atomic_gemm_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, + B, B_scale_inverse, B_fp8_tensor, B_type, + D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, + workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapBase::atomic_gemm_overlap_rs( + A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, + use_split_accumulator, gemm_overlap, rs_out_, stream_main); +} // CommOverlap::split_overlap_rs + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlap::split_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, + B, B_scale_inverse, B_fp8_tensor, B_type, + D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, + workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapBase::split_overlap_rs( + A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, + use_split_accumulator, gemm_overlap, rs_out_, stream_main); +} // CommOverlap::split_overlap_rs + +/* +** Helper function to copy input to _ubuf +*/ +void CommOverlap::copy_input_to_ubuf(torch::Tensor input, int comm_type) { + char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); + te::CommOverlapType _comm_type = static_cast(comm_type); + if (_comm_type == te::CommOverlapType::AG) { + if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + } else { + if (input.numel() != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + } + + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); + NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); + 
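+  // The device-to-device copy below is issued on the communication stream
+  // (_stream_comm), so it first waits on the event recorded above on the
+  // current compute stream. This keeps the copy of `input` into the userbuffer
+  // ordered after any kernel that may still be producing `input`.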
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)_stream_comm)); +} + +torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); + te::CommOverlapType _comm_type = static_cast(comm_type); + if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) + NVTE_ERROR("Invalid comm_type"); + if (_comm_type == te::CommOverlapType::RS) + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + int output_c_dim0 = (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) + : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + return torch::from_blob( + ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, + torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); +} + +/*************************************************************************************************** + * CommOverlapP2P + **************************************************************************************************/ + +CommOverlapP2P::CommOverlapP2P( + const std::vector &buffer_shape, at::ScalarType buffer_dtype, int myrank, + int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + CommOverlapHelper *callbacks, transformer_engine::CommOverlapType comm_type, + int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool use_ce, bool aggregate) + : te::CommOverlapP2PBase( + buffer_shape, GetTransformerEngineDType(buffer_dtype), + myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, callbacks, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, callbacks, _1), comm_type, num_max_streams, + comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {} + +/* +** Split AllGather + AtomicGEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is +*needed to have AG outputs +** in each rank to be in the contiguous memory space after all ring exchange +*phases. +*/ +void CommOverlapP2P::atomic_gemm_overlap_ag( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, + B, B_scale_inverse, B_fp8_tensor, B_type, + D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, + workspace) + + auto B_copy_ = makeTransformerEngineTensor(B_copy); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::atomic_gemm_overlap_ag( + A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, + use_split_accumulator, B_copy_, stream_main); +} // atomic_gemm_overlap_ag + +/* +** Split AllGather + GEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. 
This is +*needed to have AG outputs +** in each rank to be in the contiguous memory space after all ring exchange +*phases. +*/ +void CommOverlapP2P::split_overlap_ag( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, + B, B_scale_inverse, B_fp8_tensor, B_type, + D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, + workspace) + + auto B_copy_ = makeTransformerEngineTensor(B_copy); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::split_overlap_ag( + A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, + use_split_accumulator, B_copy_, stream_main); +} // split_overlap_ag + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2P::atomic_gemm_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, + B, B_scale_inverse, B_fp8_tensor, B_type, + D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, + workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::atomic_gemm_overlap_rs( + A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, + use_split_accumulator, rs_out_, stream_main); +} + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2P::split_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, + B, B_scale_inverse, B_fp8_tensor, B_type, + D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, + workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::split_overlap_rs( + A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, + use_split_accumulator, rs_out_, stream_main); +} + +/* +** Copy input to _ubufs[0] +*/ +void 
CommOverlapP2P::copy_input_to_ubuf(torch::Tensor input, bool chunk) { + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); + if (chunk) { + // Copy input to the target ubuf chunk by rank offset + if (input.numel() != (int64_t)_ubufs[0].numel() || + input.element_size() != (int64_t)_ubufs[0].element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.data_ptr(), + input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); + } else { + if (input.numel() != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.data_ptr(), + input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); + } +} + +torch::Tensor CommOverlapP2P::get_ubuf_output(int comm_type) { + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); + te::CommOverlapType _comm_type = static_cast(comm_type); + if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) + NVTE_ERROR("Invalid comm_type"); + if (_comm_type == te::CommOverlapType::RS) + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); + int output_c_dim0 = (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) + : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + return torch::from_blob( + ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, + torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 7bd5a2d8c8..55dc52624d 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -4,12 +4,16 @@ * See LICENSE for license information. 
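
The shared pybind_helper.h introduced above is what lets the PyTorch module below, and in principle any other framework extension, register the common enums and query helpers in one line. A minimal sketch of such a consumer follows; the module name is illustrative, and it assumes the including file defines the usual `py` alias that the macro body relies on.

    // Hypothetical consumer of the shared pybind11 handles (not part of this patch).
    #include <pybind11/pybind11.h>

    #include "common/util/pybind_helper.h"

    namespace py = pybind11;  // the macro body refers to both pybind11:: and py::

    PYBIND11_MODULE(example_te_extension, m) {
      // Registers DType, the NVTE_* attention enums, CommOverlapType/CommOverlapAlgo,
      // plus device_supports_multicast() and ubuf_built_with_mpi().
      NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m)
      // Framework-specific bindings would follow here.
    }
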
************************************************************************/ -#include +#include +#include + +#include "common/util/pybind_helper.h" -#include "../comm_gemm_overlap.h" #include "../extensions.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + // Permutation functions m.def("moe_permute_fwd", moe_permute_fwd); m.def("moe_permute_bwd", moe_permute_bwd); @@ -220,150 +224,95 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); // Data structures - py::class_(m, "FP8TensorMeta") + py::class_(m, "FP8TensorMeta", py::module_local()) .def(py::init<>()) .def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale) .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); - m.def("device_supports_multicast", &ubuf::device_supports_multicast, - py::call_guard()); + py::enum_(m, "FP8FwdTensors", py::module_local()) + .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) + .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) + .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) + .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) + .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT) + .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT) + .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT) + .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) + .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); - m.def("ubuf_built_with_mpi", &ubuf::ubuf_built_with_mpi, - py::call_guard()); + py::enum_(m, "FP8BwdTensors", py::module_local()) + .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) + .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) + .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) + .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2) + .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) + .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); - py::class_(m, "UbufBootstrapCallbacks") + py::class_(m, "CommOverlapHelper", py::module_local()) .def(py::init<>(), py::call_guard()) .def(py::init(), - py::call_guard()); - - py::enum_(m, "UbufOverlapAlgo") - .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) - .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) - .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) - .value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P) - .value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P) - .value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS) - .value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P) - .value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P); + py::call_guard(), + py::arg("world_group"), py::arg("intra_node_group")); - // Note: Can't release GIL in constructor since it may bootstrap - // communicator with Python functions (e.g. 
PyTorch distributed - // communication) - py::class_(m, "UbufCommOverlap") - .def(py::init(), + py::class_(m, "CommOverlap", py::module_local()) + .def(py::init &, at::ScalarType, int, int, int, int, int, int, int, + CommOverlapHelper *, int, int, int, int, bool, bool>(), + py::call_guard(), + py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("myrank"), py::arg("numranks"), + py::arg("mylocal"), py::arg("numlocal"), py::arg("mynode"), py::arg("numnodes"), + py::arg("tp_size"), py::arg("callbacks"), py::arg("num_splits") = 3, + py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2, + py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, + py::arg("atomic_gemm") = false) + .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) - .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap, + .def("split_overlap_rs", &CommOverlap::split_overlap_rs, py::call_guard()) - .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs, + .def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs, py::call_guard()) - .def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv, + .def("copy_input_to_ubuf", &CommOverlap::copy_input_to_ubuf, py::call_guard()) - .def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs, + .def("get_ubuf_output", &CommOverlap::get_ubuf_output, py::call_guard()) - .def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf, + .def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv, py::call_guard()) - .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf, + .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard()) - .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output, + .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard()) - .def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm, - py::call_guard()) - .def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap, + .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard()); - // Note: Can't release GIL in constructor since it may bootstrap - // communicator with Python functions (e.g. 
PyTorch distributed - // communication) - py::class_(m, "UbufP2PCommOverlap") - .def(py::init(), - py::call_guard()) - .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag, + py::class_(m, "CommOverlapP2P", py::module_local()) + .def(py::init &, at::ScalarType, int, int, int, int, int, int, int, + CommOverlapHelper *, transformer_engine::CommOverlapType, int, int, int, + bool, bool, bool, bool>(), + py::call_guard(), + py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("myrank"), py::arg("numranks"), + py::arg("mylocal"), py::arg("numlocal"), py::arg("mynode"), py::arg("numnodes"), + py::arg("tp_size"), py::arg("callbacks"), py::arg("comm_type"), + py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2, + py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, + py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) + .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag, py::call_guard()) - .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs, + .def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs, py::call_guard()) - .def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag, + .def("atomic_gemm_overlap_ag_p2p", &CommOverlapP2P::atomic_gemm_overlap_ag, py::call_guard()) - .def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs, + .def("atomic_gemm_overlap_rs_p2p", &CommOverlapP2P::atomic_gemm_overlap_rs, py::call_guard()) - .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf, + .def("copy_input_to_ubuf", &CommOverlapP2P::copy_input_to_ubuf, py::call_guard()) - .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output, + .def("get_ubuf_output", &CommOverlapP2P::get_ubuf_output, py::call_guard()) - .def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf, + .def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv, py::call_guard()) - .def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm, + .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard()) - .def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap, + .def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm, py::call_guard()) - .def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv, + .def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap, py::call_guard()); - - py::enum_(m, "DType", py::module_local()) - .value("kByte", transformer_engine::DType::kByte) - .value("kInt32", transformer_engine::DType::kInt32) - .value("kFloat32", transformer_engine::DType::kFloat32) - .value("kFloat16", transformer_engine::DType::kFloat16) - .value("kBFloat16", transformer_engine::DType::kBFloat16) - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); - - py::enum_(m, "FP8FwdTensors") - .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) - .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) - .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) - .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) - .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT) - .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT) - .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT) - .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) - .value("GEMM3_OUTPUT", 
transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); - - py::enum_(m, "FP8BwdTensors") - .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) - .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) - .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) - .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2) - .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) - .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); - - py::enum_(m, "NVTE_Bias_Type") - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); - - py::enum_(m, "NVTE_Mask_Type") - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - - py::enum_(m, "NVTE_QKV_Layout") - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - py::enum_(m, "NVTE_Fused_Attn_Backend") - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); } From a36ebf61e643d6d2e304777e0a5788ee98b2001f Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Mon, 26 Aug 2024 23:08:56 +0000 Subject: [PATCH 05/34] updated TE/PyTorch Python API to match the refactored comm+GEMM overlap code Signed-off-by: Alp Dener --- .../distributed/run_gemm_with_overlap.py | 144 ++++++++---------- .../distributed/test_comm_gemm_overlap.py | 3 + .../pytorch/cpp_extensions/gemm.py | 47 +++--- transformer_engine/pytorch/module/base.py | 50 +++--- 4 files changed, 112 insertions(+), 132 deletions(-) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 5ba70ccbdd..d8a1c548f8 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -32,8 +32,8 @@ } nvte_comm_types = { - "rs": 0, - "ag": 1, + "rs": tex.CommOverlapType.RS, + "ag": tex.CommOverlapType.AG, } @@ -75,7 +75,7 @@ def _parse_args(argv=None, namespace=None): 
parser.add_argument( "--comm-type", type=partial(_mapped_argtype, typemap=nvte_comm_types), - default=0, + default=tex.CommOverlapType.AG, help="Comm type to overlap.", ) parser.add_argument( @@ -156,12 +156,10 @@ def _parse_args(argv=None, namespace=None): if opts.fp8: warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.") opts.fp8 = False - elif opts.comm_type == 1: + elif opts.comm_type == tex.CommOverlapType.AG: if opts.atomic: setattr(opts, "atomic_rs_p2p", opts.p2p) - if not opts.p2p: - warnings.warn("All-gather overlap is only supported with point-2-point comms.") - opts.p2p = True + opts.p2p = True if opts.atomic: if not te.fp8.check_fp8_support(): @@ -284,34 +282,34 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None print("\n", end="", flush=True) ub_callbacks = ( - tex.UbufBootstrapCallbacks() + tex.CommOverlapHelper() if tex.ubuf_built_with_mpi() - else tex.UbufBootstrapCallbacks(bootstrap_pg, bootstrap_pg) + else tex.CommOverlapHelper(bootstrap_pg, bootstrap_pg) ) - if opts.comm_type == 0: + if opts.comm_type == tex.CommOverlapType.RS: if opts.bulk_overlap: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_RS + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_RS elif opts.p2p: ub_algo = ( - tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P if opts.atomic - else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P ) else: ub_algo = ( - tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + tex.CommOverlapAlgo.ATOMIC_GEMM_RS if opts.atomic - else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ) - elif opts.comm_type == 1: + elif opts.comm_type == tex.CommOverlapType.AG: if opts.bulk_overlap: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG else: ub_algo = ( - tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P if opts.atomic - else tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + else tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P ) else: raise TypeError("Invalid comm+GEMM overlap type!") @@ -322,13 +320,15 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None hidden_size = opts.num_heads * opts.head_dim inp_shape = (opts.seq_length, opts.batch_size, hidden_size) outer_size = reduce(operator.mul, inp_shape[:-1], 1) - ubuf_dtype = torch.bfloat16 - if opts.fp8 and not opts.bulk_overlap and (opts.comm_type == 1 or opts.fp8_output): - ubuf_dtype = torch.uint8 - sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda") - ub_obj = ub_obj = ( - tex.UbufP2PCommOverlap( - sample_buffer, # Sample userbuffer + buffer_dtype = torch.bfloat16 + if (opts.fp8 + and not opts.bulk_overlap + and (opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output)): + buffer_dtype = torch.uint8 + ub_obj = ( + tex.CommOverlapP2P( + (outer_size, hidden_size), + buffer_dtype, WORLD_RANK, # World rank WORLD_SIZE, # World size LOCAL_RANK, # Rank within the node @@ -336,19 +336,17 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None 0, # Node ID 1, # Number of nodes tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 1, # Number of communication SMs - 1, # CGA cluster size - opts.comm_type == 0 or opts.atomic, # Set SM margin - opts.aggregate, # Aggregate 2X GEMM chunks - 3, # Max concurrent GEMM streams - opts.comm_type == 0, # overlap with reduce scatter - opts.atomic, # use a single GEMM with 
atomic-counters - not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), ub_callbacks, + opts.comm_type, + set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic, + atomic_gemm=opts.atomic, + aggregate=opts.aggregate, + use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), ) if opts.p2p - else tex.UbufCommOverlap( - sample_buffer, # Sample userbuffer + else tex.CommOverlap( + (outer_size, hidden_size), + buffer_dtype, WORLD_RANK, # World rank WORLD_SIZE, # World size LOCAL_RANK, # Rank within the node @@ -356,27 +354,18 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None 0, # Node ID 1, # Number of nodes tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 16, # Number of communication SMs - 2, # CGA cluster size - 4, # Number of communication splits - True, # Set SM margin - 3, # Max concurrent GEMM streams - opts.atomic, # Use a single GEMM with atomic-counters ub_callbacks, + atomic_gemm=opts.atomic, ) ) # Numerical check on AG + atomic GEMM requires testing an AG+RS pair ub_obj2 = None - if opts.atomic and opts.comm_type == 1 and opts.check_numerics: - sample_buffer2 = torch.empty( - (outer_size, hidden_size), - dtype=torch.uint8 if opts.fp8_output else torch.bfloat16, - device="cuda", - ) + if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics: ub_obj2 = ( - tex.UbufP2PCommOverlap( - sample_buffer2, # Sample userbuffer + tex.CommOverlapP2P( + (outer_size, hidden_size), + torch.uint8 if opts.fp8_output else torch.bfloat16, WORLD_RANK, # World rank WORLD_SIZE, # World size LOCAL_RANK, # Rank within the node @@ -384,19 +373,15 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None 0, # Node ID 1, # Number of nodes tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 1, # Number of communication SMs - 1, # CGA cluster size - True, # Set SM margin - False, # Aggregate 2X GEMM chunks - 3, # Max concurrent GEMM streams - True, # overlap with reduce scatter - True, # use a single GEMM with atomic-counters - True, # use copy engine for P2P communications ub_callbacks, + tex.CommOverlapType.RS, + set_sm_margin=True, + atomic_gemm=True, ) if opts.atomic_rs_p2p - else tex.UbufCommOverlap( - sample_buffer2, # Sample userbuffer + else tex.CommOverlap( + (outer_size, hidden_size), + torch.uint8 if opts.fp8_output else torch.bfloat16, WORLD_RANK, # World rank WORLD_SIZE, # World size LOCAL_RANK, # Rank within the node @@ -404,13 +389,8 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None 0, # Node ID 1, # Number of nodes tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 16, # Number of communication SMs - 2, # CGA cluster size - 4, # Number of communication splits - True, # Set SM margin - 3, # Max concurrent GEMM streams - True, # uUe a single GEMM with atomic-counters ub_callbacks, + atomic_gemm=True, ) ) @@ -426,12 +406,12 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None local_kernel_t_shape = (ffn_hidden_size, hidden_size) local_inp_shape = (outer_size, hidden_size) # Bulk overlap comm tensor is distributed for AG overlap only - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: bulk_inp_shape = (outer_size // tp_size, hidden_size) else: bulk_inp_shape = (outer_size, hidden_size) else: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: # (M/P, N) -> overlapped AG -> 
(M, N) x (K/P, N)^T = (M, K/P) local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size) local_inp_shape = (outer_size // tp_size, hidden_size) @@ -472,7 +452,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None std=opts.std, ) else: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: # AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K) ker_g = torch.transpose( te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1 @@ -494,7 +474,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ).to(dtype=torch.float32) if opts.bulk_overlap: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: ref_g = te.distributed.gather_along_first_dim(bulk_inp, tp_group)[0] else: # First all-gather all the bulk inputs into a list @@ -529,7 +509,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) ref_amax = torch.max(torch.abs(ref_g)) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) - if opts.bulk_overlap and opts.comm_type == 0: + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: bulk_amax = torch.max(torch.abs(bulk_inp)) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax) elif ub_obj2 is not None: @@ -551,7 +531,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None kernel_t_fp8 = tex.cast_to_fp8( kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype ) - if opts.bulk_overlap and opts.comm_type == 0: + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: bulk_inp_fp8 = tex.cast_to_fp8( bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype ) @@ -574,7 +554,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None rtol=0.125, atol=0.0675, ) - if opts.bulk_overlap and opts.comm_type == 0: + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: torch.allclose( bulk_inp.to(dtype=torch.float32), bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT], @@ -590,7 +570,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ) # Set Fp8 scales for userbuffers - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) if ub_obj2 is not None: ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) @@ -602,7 +582,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None # Set up comm/compute buffers ubuf_out2 = None rs_out2 = None - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: if opts.bulk_overlap: ub_obj.copy_input_to_ubuf(bulk_inp, 1) gemm_inp = inp @@ -686,9 +666,9 @@ def _fp8_gemm2(gemm1_out): gelu=False, use_split_accumulator=te.module.base._2X_ACC_FPROP, ub_algo=( - tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P if opts.atomic_rs_p2p - else tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else tex.CommOverlapAlgo.ATOMIC_GEMM_RS ), ub=ub_obj2, extra_output_tensor=rs_out2, @@ -762,10 +742,10 @@ def _gemm(): avg_gpu_time = sum(gpu_times) / opts.timing_iters gemm_name = "".join( [ - "p2p all-gather + " if opts.comm_type == 1 else "", + "p2p all-gather + " if opts.comm_type == tex.CommOverlapType.AG else "", "atomic " if opts.atomic else "", "GEMM", - (f" + {'p2p ' if opts.p2p else ''}reduce-scatter" if 
opts.comm_type == 0 else ""), + (f" + {'p2p ' if opts.p2p else ''}reduce-scatter" if opts.comm_type == tex.CommOverlapType.RS else ""), ] ) timing_info = ( @@ -781,7 +761,7 @@ def _gemm(): dist.barrier(tp_group) if opts.bulk_overlap: output_info = "" - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: # Bulk overlap AG output is already gathered test_out = ub_obj.get_ubuf_output(1) else: @@ -794,7 +774,7 @@ def _gemm(): output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}" dist_print( output_info, - src=0 if opts.comm_type == 0 else None, + src=0 if opts.comm_type == tex.CommOverlapType.RS else None, section=True, ) @@ -805,7 +785,7 @@ def _gemm(): ) dist_print(nonzero_info, src=0, section=True, group=tp_group) else: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: if ub_obj2 is not None: # AG+RS Output: (M/P, N) -> gather -> (M, N) output = rs_out2.to(dtype=torch.float32) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 63310195ae..ce46a72189 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -42,6 +42,9 @@ # Force GPU kernels to launch in the order they're executed by the host CPU os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +# Clear torch.dynamo caches +torch._dynamo.reset() + def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate): test_path = TEST_ROOT / "run_gemm_with_overlap.py" diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index fd1eb4a810..932bb3cafa 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -45,8 +45,8 @@ def fp8_gemm( use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, - ub_algo: tex.UbufOverlapAlgo = None, - ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None, + ub_algo: tex.CommOverlapAlgo = None, + ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, extra_output_tensor: torch.Tensor = None, ) -> torch.Tensor: """TN layout GEMM with fp8 inputs.""" @@ -107,7 +107,7 @@ def fp8_gemm( fn = torch.ops.tex_ts.te_gemm_ts if ub_algo is not None: assert ub is not None, "ub object is None!" 
- if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: + if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: fn = ub.bulk_overlap extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor @@ -115,11 +115,11 @@ def fp8_gemm( args = tuple( args + ( - 1, + tex.CommOverlapType.AG, extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: + elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor @@ -127,23 +127,23 @@ def fp8_gemm( args = tuple( args + ( - 0, + tex.CommOverlapType.RS, extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: fn = ub.split_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: + elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P: fn = ub.atomic_gemm_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: fn = ub.split_overlap_rs assert ( extra_output_tensor is not None @@ -155,13 +155,13 @@ def fp8_gemm( extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: fn = ub.split_overlap_rs_p2p assert ( extra_output_tensor is not None ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS: + elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS: fn = ub.atomic_gemm_overlap_rs assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor" args = tuple( @@ -171,16 +171,13 @@ def fp8_gemm( extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P: + elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P: fn = ub.atomic_gemm_overlap_rs_p2p assert ( extra_output_tensor is not None ), "ATOMIC_GEMM_RS_P2P requires extra output tensor" args = tuple(args + (extra_output_tensor,)) - if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: - out = fn(*args) - else: - _ = fn(*args) + _ = fn(*args) return out, gelu_input @@ -198,8 +195,8 @@ def gemm( out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, use_bias: bool = False, - ub_algo: tex.UbufOverlapAlgo = None, - ub: tex.UbufCommOverlap = None, + ub_algo: tex.CommOverlapAlgo = None, + ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, extra_output_tensor: torch.Tensor = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Non FP8 GEMM.""" @@ -270,19 +267,19 @@ def gemm( fn = torch.ops.tex_ts.te_gemm_ts if ub_algo is not None: assert ub is not None, "ub object is None!" 
- if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: + if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: fn = ub.bulk_overlap - args = tuple(args + (1, empty_tensor)) - elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: + args = tuple(args + (tex.CommOverlapType.AG, empty_tensor)) + elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap - args = tuple(args + (0, empty_tensor)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + args = tuple(args + (tex.CommOverlapType.RS, empty_tensor)) + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: fn = ub.split_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: fn = ub.split_overlap_rs assert ( extra_output_tensor is not None @@ -294,7 +291,7 @@ def gemm( extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: fn = ub.split_overlap_rs_p2p assert ( extra_output_tensor is not None diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 12ce5f0877..6bb398d078 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -109,7 +109,7 @@ def initialize_ub( local_size = tp_size self_node_idx = world_rank // tp_size num_nodes = world_size // tp_size - ub_callbacks = tex.UbufBootstrapCallbacks() + ub_callbacks = tex.CommOverlapHelper() else: assert ( torch.distributed.is_initialized() @@ -203,7 +203,7 @@ def initialize_ub( flush=True, ) - ub_callbacks = tex.UbufBootstrapCallbacks(world_group, intra_node_group) + ub_callbacks = tex.CommOverlapHelper(world_group, intra_node_group) # Increase the workspace by the number of maximum concurrent streams global _cublas_workspace @@ -303,12 +303,11 @@ def add_ub( if atomic_gemm and method == "ring_exchange": assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message - sample_buffer = torch.empty( - shape, dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype, device="cuda" - ) + dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype if method == "ring_exchange": - ub_obj = tex.UbufP2PCommOverlap( - sample_buffer, # Sample userbuffer + ub_obj = tex.CommOverlapP2P( + shape, # Communication buffer shape + dtype, # Communication buffer data type world_rank, # World rank world_size, # World size local_rank, # Rank within the node @@ -316,19 +315,20 @@ def add_ub( self_node_idx, # Node ID num_nodes, # Number of nodes tp_size, # Tensor-parallel group size (may be different than local_size) - num_sm, # Number of communication SMs - cga_size, # CGA cluster size - set_sm_margin, # Set SM margin - aggregate, # Aggregate 2X GEMM chunks - _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams - is_reduce_scatter, # Overlap with reduce scatter - atomic_gemm, # Use a single GEMM with atomic-counters - use_ce, # Use copy engine for P2P communications - ub_callbacks, + ub_callbacks, # Helper for torch.distributed callbacks during bootstrapping + tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG, + num_max_steams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, + use_ce=use_ce, + aggregate=aggregate, ) else: - ub_obj = tex.UbufCommOverlap( - sample_buffer, # Sample userbuffer + ub_obj 
= tex.CommOverlap( + shape, # Communication buffer shape + dtype, # Communication buffer data type world_rank, # World rank world_size, # World size local_rank, # Rank within the node @@ -336,13 +336,13 @@ def add_ub( self_node_idx, # Node ID num_nodes, # Number of nodes tp_size, # Tensor-parallel group size (may be different than local_size) - num_sm, # Number of communication SMs - cga_size, # CGA cluster size - num_splits, # Number of communication splits - set_sm_margin, # Set SM margin - _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams - atomic_gemm, # Use a single GEMM with atomic-counters - ub_callbacks, + ub_callbacks, # Helper for torch.distributed callbacks during bootstrapping + num_splits=num_splits, + num_max_steams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, ) _ub_communicators[name] = ub_obj From 2d495bcdebe4206e714132e956818576790f72b8 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 27 Aug 2024 18:36:54 +0000 Subject: [PATCH 06/34] updated unit tests to work with refactored comm+GEMM overlap code Signed-off-by: Alp Dener --- .../distributed/run_gemm_with_overlap.py | 36 ++------- .../distributed/run_layer_with_overlap.py | 1 - .../common/util/pybind_helper.h | 10 +-- transformer_engine/pytorch/csrc/extensions.h | 22 +++--- .../csrc/extensions/comm_gemm_overlap.cpp | 73 +++++++++++++------ .../pytorch/csrc/extensions/pybind.cpp | 48 ++++++------ transformer_engine/pytorch/module/base.py | 67 ++++++++--------- .../pytorch/module/layernorm_linear.py | 22 +++--- .../pytorch/module/layernorm_mlp.py | 46 ++++++------ transformer_engine/pytorch/module/linear.py | 18 ++--- 10 files changed, 169 insertions(+), 174 deletions(-) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index d8a1c548f8..f742242386 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -281,10 +281,10 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None if WORLD_RANK == 0: print("\n", end="", flush=True) - ub_callbacks = ( + helper = ( tex.CommOverlapHelper() if tex.ubuf_built_with_mpi() - else tex.CommOverlapHelper(bootstrap_pg, bootstrap_pg) + else tex.CommOverlapHelper(bootstrap_pg) ) if opts.comm_type == tex.CommOverlapType.RS: @@ -329,14 +329,8 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None tex.CommOverlapP2P( (outer_size, hidden_size), buffer_dtype, - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - ub_callbacks, opts.comm_type, set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic, atomic_gemm=opts.atomic, @@ -347,14 +341,8 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None else tex.CommOverlap( (outer_size, hidden_size), buffer_dtype, - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - ub_callbacks, atomic_gemm=opts.atomic, ) ) @@ -366,14 +354,8 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None tex.CommOverlapP2P( 
(outer_size, hidden_size), torch.uint8 if opts.fp8_output else torch.bfloat16, - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - ub_callbacks, tex.CommOverlapType.RS, set_sm_margin=True, atomic_gemm=True, @@ -382,14 +364,8 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None else tex.CommOverlap( (outer_size, hidden_size), torch.uint8 if opts.fp8_output else torch.bfloat16, - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - ub_callbacks, atomic_gemm=True, ) ) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index e5653bda01..e32a7ccb12 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -9,7 +9,6 @@ import socket import argparse import warnings -from functools import partial import torch import torch.distributed as dist diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 7f903089e6..5d3f02033d 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -14,7 +14,7 @@ #include #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType", pybind11::module_local()) \ + pybind11::enum_(m, "DType") \ .value("kByte", transformer_engine::DType::kByte) \ .value("kInt32", transformer_engine::DType::kInt32) \ .value("kFloat32", transformer_engine::DType::kFloat32) \ @@ -35,7 +35,7 @@ .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_QKV_Layout", py::module_local()) \ + pybind11::enum_(m, "NVTE_QKV_Layout") \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ @@ -51,15 +51,15 @@ .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend", py::module_local()) \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_(m, "CommOverlapType", py::module_local()) \ + pybind11::enum_(m, "CommOverlapType") \ .value("RS", transformer_engine::CommOverlapType::RS) \ .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo", py::module_local()) \ + pybind11::enum_(m, "CommOverlapAlgo") \ .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ .value("SPLIT_PIPELINED_AG_P2P", \ diff --git 
a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8306add5ee..42b576389a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -7,6 +7,8 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ +#include + #include "common.h" #include "common/common.h" @@ -515,9 +517,13 @@ class CommOverlapHelper : torch::CustomClassHolder { std::map pgs; public: + int myrank, numranks, mylocal, numlocal, mynode, numnodes; + CommOverlapHelper(); - CommOverlapHelper(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group); + CommOverlapHelper(c10d::ProcessGroup *world_group, + std::optional intra_node_group_holder, + std::optional inter_node_group_holder); ~CommOverlapHelper(); @@ -531,8 +537,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOv public: CommOverlap( const std::vector &buffer_shape, at::ScalarType buffer_dtype, - int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, - CommOverlapHelper *callbacks, int num_splits = 3, + CommOverlapHelper *helper, int tp_size, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); @@ -588,12 +593,11 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOv class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { public: CommOverlapP2P( - const std::vector &buffer_shape, at::ScalarType buffer_dtype, - int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, - CommOverlapHelper *callbacks, transformer_engine::CommOverlapType comm_type, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int num_comm_sm = 3, - bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, - bool aggregate = false); + const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, transformer_engine::CommOverlapType comm_type, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, + bool aggregate = false); void set_ubuf_scale_inv(torch::Tensor scale_inv) { assert(scale_inv.numel()); diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index fa257bce62..9e08265e0b 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -70,15 +70,43 @@ CommOverlapHelper::CommOverlapHelper() { } // empty constructor for NVTE_UB_WITH_MPI=1 CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, - c10d::ProcessGroup *intra_node_group) { + std::optional intra_node_group_holder, + std::optional inter_node_group_holder) { +myrank = world_group->getRank(); +numranks = world_group->getSize(); pgs.insert({"world", world_group}); c10d::ProcessGroup::BackendType backend = world_group->getBackendType(); backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); -NVTE_CHECK(intra_node_group->getBackendType() == backend, - "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", - "group!", world_group->getBackendName()); 
-pgs.insert({"intra", intra_node_group});
+if (intra_node_group_holder.has_value()) {
+  NVTE_CHECK(inter_node_group_holder.has_value(),
+             "Internal TE error: Inter-node group cannot be `None` when intra-node group exists!");
+
+  // Get local rank on node and number of local ranks
+  c10d::ProcessGroup *intra_node_group = intra_node_group_holder.value();
+  NVTE_CHECK(intra_node_group->getBackendType() == backend,
+             "Internal TE error: Intra-node group must be on the same backend (%s) as the world ",
+             "group!", world_group->getBackendName());
+  mylocal = intra_node_group->getRank();
+  numlocal = intra_node_group->getSize();
+  pgs.insert({"intra", intra_node_group});
+
+  // Get node ID and number of nodes
+  c10d::ProcessGroup *inter_node_group = inter_node_group_holder.value();
+  NVTE_CHECK(inter_node_group->getBackendType() == backend,
+             "Internal TE error: Inter-node group must be on the same backend (%s) as the world ",
+             "group!", world_group->getBackendName());
+  mynode = inter_node_group->getRank();
+  numnodes = inter_node_group->getSize();
+} else {
+  // There is only one node so local rank/size is equal to global rank/size
+  mylocal = myrank;
+  numlocal = numranks;
+  pgs.insert({"intra", world_group});
+
+  mynode = 0;
+  numnodes = 1;
+}
 
 initialized = true;
 }
@@ -127,15 +155,14 @@ work->wait();
  **************************************************************************************************/
 
 CommOverlap::CommOverlap(
-    const std::vector &buffer_shape, at::ScalarType buffer_dtype, int myrank,
-    int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
-    CommOverlapHelper *callbacks, int num_splits, int num_max_streams, int comm_cga_size,
-    int num_comm_sm, bool set_sm_margin, bool atomic_gemm)
+    const std::vector &buffer_shape, at::ScalarType buffer_dtype,
+    CommOverlapHelper *helper, int tp_size, int num_splits, int num_max_streams,
+    int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool atomic_gemm)
     : te::CommOverlapBase(
-        buffer_shape, GetTransformerEngineDType(buffer_dtype),
-        myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
-        std::bind(&CommOverlapHelper::ub_allgather, callbacks, _1, _2, _3, _4, _5),
-        std::bind(&CommOverlapHelper::ub_barrier, callbacks, _1), num_splits, num_max_streams,
+        buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks,
+        helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size,
+        std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5),
+        std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, num_max_streams,
         comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) {}
 
 /*
@@ -271,17 +298,17 @@ torch::Tensor CommOverlap::get_ubuf_output(int comm_type) {
  **************************************************************************************************/
 
 CommOverlapP2P::CommOverlapP2P(
-    const std::vector &buffer_shape, at::ScalarType buffer_dtype, int myrank,
-    int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
-    CommOverlapHelper *callbacks, transformer_engine::CommOverlapType comm_type,
-    int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool atomic_gemm,
-    bool use_ce, bool aggregate)
+    const std::vector &buffer_shape, at::ScalarType buffer_dtype,
+    CommOverlapHelper *helper, int tp_size, te::CommOverlapType comm_type, int num_max_streams,
+    int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce,
+    bool aggregate)
    : te::CommOverlapP2PBase(
-        
buffer_shape, GetTransformerEngineDType(buffer_dtype), - myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, callbacks, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, callbacks, _1), comm_type, num_max_streams, - comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {} + buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, + helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, + num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, + aggregate) {} /* ** Split AllGather + AtomicGEMM using P2P communication diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 55dc52624d..5e7d173d19 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -224,13 +224,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); // Data structures - py::class_(m, "FP8TensorMeta", py::module_local()) + py::class_(m, "FP8TensorMeta") .def(py::init<>()) .def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale) .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); - py::enum_(m, "FP8FwdTensors", py::module_local()) + py::enum_(m, "FP8FwdTensors") .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) @@ -241,7 +241,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); - py::enum_(m, "FP8BwdTensors", py::module_local()) + py::enum_(m, "FP8BwdTensors") .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) @@ -249,22 +249,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); - py::class_(m, "CommOverlapHelper", py::module_local()) + py::class_(m, "CommOverlapHelper") .def(py::init<>(), py::call_guard()) - .def(py::init(), + .def(py::init, + std::optional>(), py::call_guard(), - py::arg("world_group"), py::arg("intra_node_group")); + py::arg("world_group"), py::arg("intra_node_group") = py::none(), + py::arg("inter_node_group") = py::none()); - py::class_(m, "CommOverlap", py::module_local()) - .def(py::init &, at::ScalarType, int, int, int, int, int, int, int, - CommOverlapHelper *, int, int, int, int, bool, bool>(), + py::class_(m, "CommOverlap") + .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, + int, int, bool, bool>(), py::call_guard(), - py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("myrank"), py::arg("numranks"), - py::arg("mylocal"), py::arg("numlocal"), py::arg("mynode"), py::arg("numnodes"), - py::arg("tp_size"), py::arg("callbacks"), py::arg("num_splits") = 3, - py::arg("num_max_streams") = 
NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2, - py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, - py::arg("atomic_gemm") = false) + py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), + py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, + py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16, + py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false) .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) .def("split_overlap_rs", &CommOverlap::split_overlap_rs, @@ -284,17 +284,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard()); - py::class_(m, "CommOverlapP2P", py::module_local()) - .def(py::init &, at::ScalarType, int, int, int, int, int, int, int, - CommOverlapHelper *, transformer_engine::CommOverlapType, int, int, int, - bool, bool, bool, bool>(), + py::class_(m, "CommOverlapP2P") + .def(py::init &, at::ScalarType, CommOverlapHelper *, int, + transformer_engine::CommOverlapType, int, int, int, bool, bool, bool, bool>(), py::call_guard(), - py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("myrank"), py::arg("numranks"), - py::arg("mylocal"), py::arg("numlocal"), py::arg("mynode"), py::arg("numnodes"), - py::arg("tp_size"), py::arg("callbacks"), py::arg("comm_type"), - py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2, - py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, - py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) + py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), + py::arg("comm_type"), py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, + py::arg("comm_cga_size") = 1, py::arg("num_comm_sm") = 1, + py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, + py::arg("use_ce") = true, py::arg("aggregate") = false) .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag, py::call_guard()) .def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6bb398d078..c7e60d7b12 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -99,27 +99,23 @@ def initialize_ub( _ub_communicators = {} if tex.ubuf_built_with_mpi(): - # Userbuffers will ignore all these values when it is built with MPI, so these are just - # placeholders based on an assumption that tp_size covers all devices in a physical node. + # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force + # an MPI_Init() here by creating a new MPI process group... assert torch.distributed.is_mpi_available() - mpi_group = torch.distributed.new_group(backend="mpi") - world_rank = torch.distributed.get_rank(mpi_group) - world_size = torch.distributed.get_world_size(mpi_group) - local_rank = world_rank % tp_size - local_size = tp_size - self_node_idx = world_rank // tp_size - num_nodes = world_size // tp_size - ub_callbacks = tex.CommOverlapHelper() + _ = torch.distributed.new_group(backend="mpi") + helper = tex.CommOverlapHelper() else: + # Bootstrapping with torch.distributed API, so check backend and construct + # intra/inter-node process groups... 
         assert (
             torch.distributed.is_initialized()
         ), "torch.distributed must be initialized before Userbuffers"
 
         if bootstrap_backend is None:
             bootstrap_backend = "nccl"
-            if torch.distributed.is_gloo_available():
-                bootstrap_backend = "gloo"
-            elif torch.distributed.is_mpi_available():
+            if torch.distributed.is_mpi_available():
                 bootstrap_backend = "mpi"
+            elif torch.distributed.is_gloo_available():
+                bootstrap_backend = "gloo"
         else:
             assert bootstrap_backend in ["gloo", "mpi", "nccl"]
@@ -184,26 +180,33 @@ def initialize_ub(
             ranks_per_node_list, backend=bootstrap_backend
         )
         local_rank = torch.distributed.get_rank(intra_node_group)
-        local_size = torch.distributed.get_world_size(intra_node_group)
         intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group)
 
+        ranks_per_node_tensor = torch.tensor(ranks_per_node_list, dtype=int)
+        ranks_across_nodes_list = ranks_per_node_tensor.transpose(0, 1).tolist()
+        inter_node_group, _ = torch.distributed.new_subgroups_by_enumeration(
+            ranks_across_nodes_list, backend=bootstrap_backend
+        )
+
+        helper = tex.CommOverlapHelper(world_group, intra_node_group, inter_node_group)
+
     else:
         self_node_idx = 0
-        intra_node_group = world_group
         local_rank = world_rank
-        local_size = world_size
         intra_node_ranks = list(range(world_size))
 
+        helper = tex.CommOverlapHelper(world_group)
+
     if world_rank == 0:
-        print(f"!!! [UB] Number of physical nodes: {num_nodes}\n", end="", flush=True)
+        print(f"!!! [UB] Number of NVLink domains: {num_nodes}\n", end="", flush=True)
     if local_rank == 0:
         print(
-            f"!!! [UB] Global ranks on node {self_node_idx}: {intra_node_ranks}\n",
+            f"!!! [UB] Global ranks in NVLink domain #{self_node_idx}: {intra_node_ranks}\n",
             end="",
             flush=True,
         )
 
-    ub_callbacks = tex.CommOverlapHelper(world_group, intra_node_group)
+
     # Increase the workspace by the number of maximum concurrent streams
     global _cublas_workspace
@@ -303,21 +306,15 @@ def add_ub(
         if atomic_gemm and method == "ring_exchange":
             assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message
 
-        dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
+        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
         if method == "ring_exchange":
             ub_obj = tex.CommOverlapP2P(
                 shape,  # Communication buffer shape
-                dtype,  # Communication buffer data type
-                world_rank,  # World rank
-                world_size,  # World size
-                local_rank,  # Rank within the node
-                local_size,  # Number of ranks/GPUs per node
-                self_node_idx,  # Node ID
-                num_nodes,  # Number of nodes
+                buffer_dtype,  # Communication buffer data type
+                helper,  # Helper for torch.distributed callbacks during bootstrapping
                 tp_size,  # Tensor-parallel group size (may be different than local_size)
-                ub_callbacks,  # Helper for torch.distributed callbacks during bootstrapping
                 tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
-                num_max_steams=_NUM_MAX_UB_STREAMS,
+                num_max_streams=_NUM_MAX_UB_STREAMS,
                 comm_cga_size=cga_size,
                 num_comm_sm=num_sm,
                 set_sm_margin=set_sm_margin,
@@ -328,17 +325,11 @@ def add_ub(
         else:
             ub_obj = tex.CommOverlap(
                 shape,  # Communication buffer shape
-                dtype,  # Communication buffer data type
-                world_rank,  # World rank
-                world_size,  # World size
-                local_rank,  # Rank within the node
-                local_size,  # Number of ranks/GPUs per node
-                self_node_idx,  # Node ID
-                num_nodes,  # Number of nodes
+                buffer_dtype,  # Communication buffer data type
+                helper,  # Helper for torch.distributed callbacks during bootstrapping
                 tp_size,  # Tensor-parallel group size (may be different than local_size)
-                ub_callbacks,  # Helper for 
torch.distributed callbacks during bootstrapping num_splits=num_splits, - num_max_steams=_NUM_MAX_UB_STREAMS, + num_max_streams=_NUM_MAX_UB_STREAMS, comm_cga_size=cga_size, num_comm_sm=num_sm, set_sm_margin=set_sm_margin, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6dea806993..f2f909bfde 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -161,9 +161,9 @@ def forward( if not return_layernorm_output: ln_out = torch.empty_like(ln_out) if ub_obj_lnout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif parallel_mode == "column" and sequence_parallel: ln_out_gathered = True ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -293,7 +293,7 @@ def forward( get_workspace(), bias=bias, use_bias=use_bias, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, ) @@ -485,7 +485,7 @@ def backward( rs_out = None if ctx.ub_bulk_dgrad: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout elif ctx.ub_overlap_rs_dgrad: dim_size = list(grad_output.size()) @@ -496,14 +496,14 @@ def backward( ) if ub_obj_dgrad.is_p2p_overlap(): if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ub_obj = ub_obj_dgrad else: ub_algo = None @@ -616,7 +616,7 @@ def backward( out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -640,7 +640,7 @@ def backward( accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -658,7 +658,7 @@ def backward( use_bias=ctx.use_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, + ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) clear_tensor_data(ln_out_total) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6c1633111d..6ae13567e5 100644 --- 
a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -180,9 +180,9 @@ def forward( ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) if ub_obj_lnout.is_atomic_gemm(): - ub_algo_ag = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo_ag = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo_ag = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo_ag = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif set_parallel_mode and sequence_parallel: ln_out_gathered = True ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -298,14 +298,14 @@ def forward( rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) if ub_obj_fc2out.is_p2p_overlap(): if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS if ub_obj_fc2out.is_fp8_ubuf(): fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT @@ -369,7 +369,7 @@ def forward( bias=fc1_bias, use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, gelu=not bias_gelu_nvfusion and (activation == "gelu"), - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, ) @@ -410,9 +410,9 @@ def forward( dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) if ub_obj_fc2out.is_p2p_overlap(): - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(gelu_out.size()) dim_size[1] = fc2_weight.size(0) @@ -615,9 +615,9 @@ def backward( dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub("fc2_dgrad") if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess ( @@ -788,7 +788,7 @@ def backward( # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap rs_out = None if ctx.ub_bulk_dgrad: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout elif ctx.ub_overlap_rs_dgrad: dim_size = list(dgelu.size()) @@ -797,14 +797,14 @@ def backward( rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) if ub_obj_dgrad.is_p2p_overlap(): if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: 
- ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ub_obj = ub_obj_dgrad else: ub_algo = None @@ -842,7 +842,7 @@ def backward( grad=True, gelu_input=fc1_out, ub_algo=( - tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None + tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None ), ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) @@ -892,7 +892,7 @@ def backward( # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap if ctx.ub_bulk_dgrad: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout elif ctx.ub_overlap_rs_dgrad: dim_size = list(dgelu.size()) @@ -900,9 +900,9 @@ def backward( dim_size[1] = fc1_weight.size(1) rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) if ub_obj_dgrad.is_p2p_overlap(): - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ub_obj = ub_obj_dgrad else: ub_algo = None @@ -967,7 +967,7 @@ def backward( out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -991,7 +991,7 @@ def backward( accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -1009,7 +1009,7 @@ def backward( use_bias=not ctx.bias_gelu_nvfusion, accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, + ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) clear_tensor_data(ln_out_total, dgelu) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f521cf4fb6..1b4a903523 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -190,14 +190,14 @@ def forward( rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) if ub_obj_projout.is_p2p_overlap(): if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS if ub_obj_projout.is_fp8_ubuf(): proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT meta_tensor = fp8_meta["scaling_fwd"] @@ -269,9 +269,9 @@ def forward( dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) if 
ub_obj_projout.is_p2p_overlap(): - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(inputmat_total.size()) dim_size[1] = out_features @@ -407,9 +407,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P ( grad_output, @@ -496,7 +496,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], layout="NN", grad=True, ub_algo=( - tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None ), From 0bd4822f5721d6c5bbe476b6538436882024a059 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 27 Aug 2024 18:48:28 +0000 Subject: [PATCH 07/34] added a pylint exception to comm+GEMM overlap test runner Signed-off-by: Alp Dener --- tests/pytorch/distributed/run_gemm_with_overlap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index f742242386..9b177f3ebd 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -461,7 +461,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None else: ref_g = torch.matmul(inp_g, ker_g) if ub_obj2 is not None: - inp2_g = torch.nn.functional.gelu(ref_g) + inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable ref2_g = torch.matmul(inp2_g, ker2_g) if opts.fp8: From cb9f235d6edc52d8a9c1904b3dd2ecb4bbe6e203 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 19:10:02 +0000 Subject: [PATCH 08/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../distributed/run_gemm_with_overlap.py | 12 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 406 +++++++++--------- .../userbuffers/userbuffers-host.cpp | 22 +- .../userbuffers/userbuffers.cu | 5 +- .../userbuffers/userbuffers.h | 26 +- .../transformer_engine/comm_gemm_overlap.h | 109 +++-- .../common/util/pybind_helper.h | 127 +++--- transformer_engine/pytorch/csrc/common.h | 4 +- transformer_engine/pytorch/csrc/extensions.h | 122 +++--- .../csrc/extensions/comm_gemm_overlap.cpp | 405 +++++++++-------- .../pytorch/csrc/extensions/pybind.cpp | 38 +- transformer_engine/pytorch/module/base.py | 2 - 12 files changed, 627 insertions(+), 651 deletions(-) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 9b177f3ebd..fcf003e380 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -321,9 +321,11 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None inp_shape = (opts.seq_length, opts.batch_size, hidden_size) outer_size = reduce(operator.mul, inp_shape[:-1], 1) buffer_dtype = torch.bfloat16 - if (opts.fp8 + if ( + opts.fp8 and not opts.bulk_overlap - and 
(opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output)): + and (opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output) + ): buffer_dtype = torch.uint8 ub_obj = ( tex.CommOverlapP2P( @@ -721,7 +723,11 @@ def _gemm(): "p2p all-gather + " if opts.comm_type == tex.CommOverlapType.AG else "", "atomic " if opts.atomic else "", "GEMM", - (f" + {'p2p ' if opts.p2p else ''}reduce-scatter" if opts.comm_type == tex.CommOverlapType.RS else ""), + ( + f" + {'p2p ' if opts.p2p else ''}reduce-scatter" + if opts.comm_type == tex.CommOverlapType.RS + else "" + ), ] ) timing_info = ( diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index d66d5bfa77..ec500610e4 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -4,12 +4,12 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include - -#include #include #include +#include + +#include +#include #include "common/common.h" #include "common/util/cuda_driver.h" @@ -48,11 +48,11 @@ bool ubuf_built_with_mpi() { #endif } -CommOverlapCore::CommOverlapCore( - int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, - ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, - int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool use_ce, - bool atomic_gemm) { +CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, int tp_size, ExtAllgatherOp allgather_handle, + ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm) { // Initialize userbuf communicator if (!_comm_created) { if (myrank == 0) { @@ -61,9 +61,8 @@ CommOverlapCore::CommOverlapCore( #ifdef NVTE_UB_WITH_MPI create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); #else - create_communicator_grouped2( - &_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, allgather_handle, - barrier_handle, 1, 1, tp_size, 1); + create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + allgather_handle, barrier_handle, 1, 1, tp_size, 1); #endif _comm_created = true; } @@ -95,8 +94,8 @@ CommOverlapCore::CommOverlapCore( NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes)); NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes)); NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2)); - _counter = TensorWrapper(counter_ptr, std::vector{(size_t)_num_splits * 2}, - DType::kInt32); + _counter = + TensorWrapper(counter_ptr, std::vector{(size_t)_num_splits * 2}, DType::kInt32); } // CUDA event creation cudaEventCreateWithFlags(&_start_compute, 0); @@ -123,21 +122,21 @@ CommOverlapCore::~CommOverlapCore() { #endif _comm_created = false; } - } +} /*************************************************************************************************** * Comm+GEMM Overlap Base (Pipelined / Collective) **************************************************************************************************/ -CommOverlapBase::CommOverlapBase( - const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, - int mylocal, int numlocal, int mynode, int numnodes, int tp_size, - ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int 
num_splits, - int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool atomic_gemm) - : CommOverlapCore( - myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, - barrier_handle, num_splits, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, - false, atomic_gemm) { +CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, int tp_size, ExtAllgatherOp allgather_handle, + ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool atomic_gemm) + : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, + num_comm_sm, set_sm_margin, false, atomic_gemm) { _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", @@ -147,8 +146,7 @@ CommOverlapBase::CommOverlapBase( size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_ub_comm->myrank == 0) - printf("!!! [UB] Register UBuf %d\n", _ub_reg); + if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1)); @@ -164,11 +162,12 @@ CommOverlapBase::~CommOverlapBase() { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ -void CommOverlapBase::bulk_overlap( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - CommOverlapType comm_type, const TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, const TensorWrapper &D, const TensorWrapper &bias, + const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, + const TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -190,8 +189,8 @@ void CommOverlapBase::bulk_overlap( assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); assert(rs_output.element_size() == 2); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - reducescatter2_userbuff_fp8<__nv_fp8_e5m2>( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, + comm_elements, _ub_comm, _stream_comm); } else { reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); } @@ -244,12 +243,12 @@ void CommOverlapBase::atomic_gemm_overlap_rs( assert(pre_gelu_out.numel() == 0); auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr); - auto workspace_chunk = TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, - workspace.dtype()); - 
nvte_cublas_atomic_gemm( - A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, - workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _num_splits, 0, true, - _counter.data(), _stream_compute[0]); + auto workspace_chunk = + TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(), + _stream_compute[0]); for (int i = 0; i < _num_splits; i++) { if (_rs_kernel_type == 1) { @@ -261,12 +260,12 @@ void CommOverlapBase::atomic_gemm_overlap_rs( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_strided_atomic_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, - _num_splits, &counter_ptr[i], _ub_comm, _stream_comm);); + rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, + &counter_ptr[i], _ub_comm, _stream_comm);); } else { - reducescatter2_userbuff_strided_atomic( - rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, _num_splits, &counter_ptr[i], - _ub_comm, _stream_comm); + reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _num_splits, &counter_ptr[i], _ub_comm, + _stream_comm); } } else if (_rs_kernel_type == 2) { if (_ubuf.element_size() == 1) { @@ -277,9 +276,9 @@ void CommOverlapBase::atomic_gemm_overlap_rs( rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, counter_ptr, _ub_comm, _stream_comm);); } else { - reducescatter2_userbuff_strided_multiatomic( - rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, _num_splits, counter_ptr, _ub_comm, - _stream_comm); + reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, + _num_splits, counter_ptr, _ub_comm, + _stream_comm); } break; } else { @@ -287,12 +286,12 @@ void CommOverlapBase::atomic_gemm_overlap_rs( if (_ubuf.element_size() == 1) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, _ub_comm, - _stream_comm);); + reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, _ubuf_scale_inv, + _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, _stream_comm);); } else { - reducescatter2_userbuff_strided( - rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, _ub_comm, _stream_comm); + reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, _stream_comm); } } @@ -309,11 +308,12 @@ void CommOverlapBase::atomic_gemm_overlap_rs( /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::split_overlap_rs( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, const TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, const TensorWrapper &D, + const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, bool gemm_overlap, + const TensorWrapper &rs_output, 
cudaStream_t stream_main) { // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -344,17 +344,16 @@ void CommOverlapBase::split_overlap_rs( assert(pre_gelu_out.numel() == 0); if (gemm_overlap) { - auto input_a_chunk = TensorWrapper( - A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); - auto output_chunk = TensorWrapper( - _ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); + auto input_a_chunk = + TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); + auto output_chunk = + TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); auto workspace_chunk = TensorWrapper( workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); - nvte_cublas_gemm( - input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, - _stream_compute[0]); + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _stream_compute[0]); for (int i = 1; i < _num_splits; i++) { input_a_chunk_ptr += input_a_chunk_size * B.element_size(); @@ -362,23 +361,20 @@ void CommOverlapBase::split_overlap_rs( char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - input_a_chunk = TensorWrapper( - reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, A.dtype(), nullptr, nullptr, - A.scale_inv()); - output_chunk = TensorWrapper( - reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, D.dtype(), D.amax(), - D.scale(), nullptr); - workspace_chunk = TensorWrapper( - reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, - workspace.dtype()); - - nvte_cublas_gemm( - input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, - _math_sms, _stream_compute[i % _stream_compute.size()]); - - NVTE_CHECK_CUDA(cudaEventRecord( - _start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); + input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, + A.dtype(), nullptr, nullptr, A.scale_inv()); + output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, + D.dtype(), D.amax(), D.scale(), nullptr); + workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA( + cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); // Communication chunk @@ -387,20 +383,18 @@ void CommOverlapBase::split_overlap_rs( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, - m, _ub_comm, _stream_comm);); + rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm);); } else { - reducescatter2_userbuff_stridedoutput( - rs_output_ptr, _ub_reg, (i - 1) * 
output_chunk_size, m_chunk, n, m, _ub_comm, - _stream_comm); + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm); } rs_output_ptr += m_chunk * rs_output.element_size(); } int last_compute_stream_id = (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); - NVTE_CHECK_CUDA( - cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id])); + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); // Last communication chunk with max SM @@ -413,32 +407,29 @@ void CommOverlapBase::split_overlap_rs( rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { - reducescatter2_userbuff_stridedoutput( - rs_output_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, - _stream_comm); + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, + (_num_splits - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm); } } else { for (int i = 0; i < _num_splits; i++) { char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto input_a_chunk = TensorWrapper( - reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, A.dtype(), nullptr, nullptr, - A.scale_inv()); - auto output_chunk = TensorWrapper( - reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, D.dtype(), D.amax(), - D.scale(), nullptr); - auto workspace_chunk = TensorWrapper( - reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, - workspace.dtype()); - - nvte_cublas_gemm( - input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, - _math_sms, _stream_compute[i % _stream_compute.size()]); - - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, - _stream_compute[i % _stream_compute.size()])); + auto input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, + A.dtype(), nullptr, nullptr, A.scale_inv()); + auto output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), + {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); // Communication chunk. 
Uses MAX_SM at the last chunk @@ -453,8 +444,8 @@ void CommOverlapBase::split_overlap_rs( rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { - reducescatter2_userbuff_stridedoutput( - rs_output_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm); + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm); } rs_output_ptr += m_chunk * rs_output.element_size(); @@ -476,16 +467,16 @@ void CommOverlapBase::split_overlap_rs( * Comm+GEMM Overlap P2P Base (Ring-Exchange) **************************************************************************************************/ -CommOverlapP2PBase::CommOverlapP2PBase( - const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, - int mylocal, int numlocal, int mynode, int numnodes, int tp_size, - ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, CommOverlapType comm_type, - int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool use_ce, - bool atomic_gemm, bool aggregate) - : CommOverlapCore( - myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, - barrier_handle, tp_size, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, - use_ce, atomic_gemm) { +CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate) + : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -503,17 +494,15 @@ CommOverlapP2PBase::CommOverlapP2PBase( void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_rank == 0) - printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); _ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, buffer_dtype); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); for (int i = 0; i < _num_ubuf_chunks; i++) { - _ubufs.push_back(TensorWrapper( - reinterpret_cast(ubuf_byte_ptr), {buffer_shape[0] / tp_size, buffer_shape[1]}, - buffer_dtype)); + _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), + {buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype)); ubuf_byte_ptr += buffer_chunk_bytes; } @@ -591,8 +580,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv()); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, - workspace.dtype()); + auto workspace_chunk = + TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); for (int i = 0; i < _tp_size - 1; i++) { // Set the userbuffer id. 
Buffer under send is the input for the current @@ -607,9 +596,9 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( if (_use_multiatomic_ag) { if (i == 0) { _ub_comm->use_ce = 0; - userbuffers_sendrecv_multiatomic( - _ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, _ub_comm, _next_rank, _prev_rank, - _tp_size, counter_ptr, true, _stream_recv); + userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, + _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, + true, _stream_recv); } } else { userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank, @@ -619,10 +608,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( producer(counter_ptr, recv_chunk_id, _stream_recv); } if (i == 0) { - nvte_cublas_atomic_gemm( - A.data(), input_b.data(), D_buffer.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, - _math_sms, 0, _tp_size, false, _counter.data(), stream_main); + nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false, + _counter.data(), stream_main); } } @@ -640,9 +629,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( // Copy the first GEMM output chunk to the end chunk position of D_buffer char *src_ptr = reinterpret_cast(D_buffer.dptr()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, - D_chunk_bytes, cudaMemcpyDeviceToDevice, - stream_main)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes, + cudaMemcpyDeviceToDevice, stream_main)); // Return the last N rows of D_buffer NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(), @@ -661,11 +649,13 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( ** in each rank to be in the contiguous memory space after all ring exchange *phases. 
*/ -void CommOverlapP2PBase::split_overlap_ag( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &B_copy, cudaStream_t stream_main) { +void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, + const TensorWrapper &pre_gelu_out, + const TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + const TensorWrapper &B_copy, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -724,32 +714,31 @@ void CommOverlapP2PBase::split_overlap_ag( // GEMM char *input_b_chunk_ptr = input_b_ptr + send_offset; - auto input_b_chunk = TensorWrapper( - reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), nullptr, - nullptr, B.scale_inv()); + auto input_b_chunk = + TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), + nullptr, nullptr, B.scale_inv()); char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); - auto output_chunk = TensorWrapper( - reinterpret_cast(output_chunk_ptr), {n_chunk * 2, m}, D.dtype(), D.amax(), - D.scale(), nullptr); + auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), + {n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); - char *aux_chunk_ptr = (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) - : nullptr; - auto aux_chunk_shape = (do_gelu) ? std::vector{n_chunk * 2, m} - : std::vector{0}; - auto aux_chunk = TensorWrapper( - reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, pre_gelu_out.dtype()); + char *aux_chunk_ptr = + (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; + auto aux_chunk_shape = + (do_gelu) ? 
std::vector{n_chunk * 2, m} : std::vector{0}; + auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, + pre_gelu_out.dtype()); char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = TensorWrapper( - reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, - workspace.dtype()); + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); - nvte_cublas_gemm( - A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, - _math_sms, _stream_compute[i % _stream_compute.size()]); + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); if (i < num_steps - 1) { // P2P communication @@ -759,8 +748,8 @@ void CommOverlapP2PBase::split_overlap_ag( prev_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - _stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); @@ -781,31 +770,29 @@ void CommOverlapP2PBase::split_overlap_ag( int recv_offset = comm_bytes * recv_chunk_id; // GEMM - auto input_b_chunk = TensorWrapper( - _ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(), nullptr, nullptr, B.scale_inv()); + auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(), + nullptr, nullptr, B.scale_inv()); char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); - auto output_chunk = TensorWrapper( - reinterpret_cast(output_chunk_ptr), {n_chunk, m}, D.dtype(), D.amax(), D.scale(), - nullptr); + auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), {n_chunk, m}, + D.dtype(), D.amax(), D.scale(), nullptr); - char *aux_chunk_ptr = (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) - : nullptr; - auto aux_chunk_shape = (do_gelu) ? std::vector{n_chunk, m} - : std::vector{0}; - auto aux_chunk = TensorWrapper( - reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, pre_gelu_out.dtype()); + char *aux_chunk_ptr = + (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; + auto aux_chunk_shape = (do_gelu) ? 
std::vector{n_chunk, m} : std::vector{0}; + auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, + pre_gelu_out.dtype()); char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = TensorWrapper( - reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, - workspace.dtype()); + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); - nvte_cublas_gemm( - A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, - _math_sms, _stream_compute[i % _stream_compute.size()]); + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); if (i < _tp_size - 1) { // P2P communication @@ -815,8 +802,8 @@ void CommOverlapP2PBase::split_overlap_ag( _prev_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - _stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); @@ -866,12 +853,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( // Process GEMM chunks in the order that AG+GEMM places the output chunks. 
auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = TensorWrapper( - workspace.data(), std::vector{workspace_size_chunk}, workspace.dtype()); - nvte_cublas_atomic_gemm( - A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, - workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, 0, _tp_size, true, - _counter.data(), stream_main); + auto workspace_chunk = + TensorWrapper(workspace.data(), std::vector{workspace_size_chunk}, workspace.dtype()); + nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(), + stream_main); // P2P communication chunk for (int i = 1; i < _tp_size; i++) { @@ -899,7 +886,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, - _ubufs[0].numel(), stream_main);); + _ubufs[0].numel(), stream_main);); } else { reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); } @@ -945,22 +932,20 @@ void CommOverlapP2PBase::split_overlap_rs( int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); - auto input_b_chunk = TensorWrapper( - reinterpret_cast(input_b_chunk_ptr), {n_chunk, k}, B.dtype(), nullptr, nullptr, - B.scale_inv()); + auto input_b_chunk = TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk, k}, + B.dtype(), nullptr, nullptr, B.scale_inv()); - auto output_chunk = TensorWrapper( - _ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr); + auto output_chunk = + TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr); char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = TensorWrapper( - reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, - workspace.dtype()); + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); - nvte_cublas_gemm( - A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, - _stream_compute[i % _stream_compute.size()]); + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); if (i > 0) { // P2P communication chunk @@ -968,18 +953,17 @@ void CommOverlapP2PBase::split_overlap_rs( int recv_offset = comm_bytes * (i - 1 + _tp_size); int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - NVTE_CHECK_CUDA(cudaEventRecord( - _start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); + NVTE_CHECK_CUDA( + cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0)); - 
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - send_rank, _stream_send); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - recv_rank, _stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + _stream_send); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + _stream_recv); } } - for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); @@ -998,7 +982,7 @@ void CommOverlapP2PBase::split_overlap_rs( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, - _ubufs[0].numel(), stream_main);); + _ubufs[0].numel(), stream_main);); } else { reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 7f6ece4eac..9e4bc7d3dc 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -106,10 +106,10 @@ int pipe_rank(communicator *comm, int step) { return newnode * numlocal + newlocal; } -int create_communicator_grouped2( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, int pipegpus, - int pipenodes, int tensorgpus, int tensornodes) { +int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes, int tensorgpus, int tensornodes) { *comm = new communicator(); (*comm)->comm_world = EXT_COMM_WORLD; @@ -345,17 +345,17 @@ int create_communicator_grouped2( return 0; } -int create_communicator_grouped( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, int pipegpus, - int pipenodes) { +int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes) { return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1); } -int create_communicator( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier) { +int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, ExtAllgatherOp ext_allgather, + ExtBarrierOp ext_barrier) { return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, ext_allgather, ext_barrier, 1, 1, 1, 1); } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 37b9053696..a45d91a387 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ 
b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -2499,13 +2499,12 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i // reset counters kernel static __global__ void reset_counters_kernel(void *atomic_ptr, int num_chunks, bool allgather) { if (blockIdx.x == 0 && threadIdx.x == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < num_chunks; i++) { ((unsigned int *)atomic_ptr)[i] = 1; ((unsigned int *)atomic_ptr)[i + num_chunks] = 0; } - if (allgather) - ((unsigned int *)atomic_ptr)[0] = 0; + if (allgather) ((unsigned int *)atomic_ptr)[0] = 0; } } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index eb2d812824..548f452ac3 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -167,19 +167,19 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream); /* creates communicator, allocates all internal buffers if necessary */ -int create_communicator_grouped2( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, int pipegpus, - int pipenodes, int tensorgpus, int tensornodes); - -int create_communicator_grouped( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, int pipegpus, - int pipenodes); - -int create_communicator( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier); +int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes, int tensorgpus, int tensornodes); + +int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes); + +int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, ExtAllgatherOp ext_allgather, + ExtBarrierOp ext_barrier); int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes, int tensorgpus, int tensornodes); diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index fdf9158f57..55622d6d43 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -7,13 +7,12 @@ #ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_ #define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_ -#include - #include #include - #include +#include + #include "common/comm_gemm_overlap/userbuffers/userbuffers.h" #define NVTE_COMM_OVERLAP_MAX_STREAMS 3 @@ -62,11 +61,10 @@ class CommOverlapCore { std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm; - CommOverlapCore( - int myrank, int numranks, int mylocal, int numlocal, 
int mynode, int numnodes, int tp_size, - ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, - int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool use_ce, - bool atomic_gemm); + CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, + int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, + bool set_sm_margin, bool use_ce, bool atomic_gemm); virtual ~CommOverlapCore(); @@ -88,12 +86,11 @@ class CommOverlapBase : public CommOverlapCore { cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; - CommOverlapBase( - const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, - int mylocal, int numlocal, int mynode, int numnodes, int tp_size, - ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, + int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); ~CommOverlapBase(); @@ -101,29 +98,30 @@ class CommOverlapBase : public CommOverlapCore { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ - void bulk_overlap( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - CommOverlapType comm_type, const TensorWrapper &rs_output, cudaStream_t stream_main); + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, + const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, + const TensorWrapper &rs_output, cudaStream_t stream_main); /* ** Split FPROP GEMM + ReduceScatter */ - void atomic_gemm_overlap_rs( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, const TensorWrapper &rs_output, cudaStream_t stream_main); + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, const TensorWrapper &D, const TensorWrapper &bias, + const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, const TensorWrapper &rs_output, + cudaStream_t stream_main); /* ** Split FPROP GEMM + ReduceScatter */ - void split_overlap_rs( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, const TensorWrapper 
&rs_output, cudaStream_t stream_main); + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, + const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, bool gemm_overlap, + const TensorWrapper &rs_output, cudaStream_t stream_main); }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { @@ -144,13 +142,12 @@ class CommOverlapP2PBase : public CommOverlapCore { cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; - CommOverlapP2PBase( - const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, - int mylocal, int numlocal, int mynode, int numnodes, int tp_size, - ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, CommOverlapType comm_type, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 1, - int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, - bool atomic_gemm = false, bool aggregate = false); + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, + int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, + int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false, + bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); ~CommOverlapP2PBase(); @@ -161,11 +158,11 @@ class CommOverlapP2PBase : public CommOverlapCore { ** in each rank to be in the contiguous memory space after all ring exchange *phases. */ - void atomic_gemm_overlap_ag( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &B_copy, cudaStream_t stream_main); + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, const TensorWrapper &D, const TensorWrapper &bias, + const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &B_copy, cudaStream_t stream_main); /* ** Split AllGather + GEMM using P2P communication @@ -174,29 +171,29 @@ class CommOverlapP2PBase : public CommOverlapCore { ** in each rank to be in the contiguous memory space after all ring exchange *phases. 
*/ - void split_overlap_ag( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &B_copy, cudaStream_t stream_main); + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, + const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &B_copy, cudaStream_t stream_main); /* ** Split ReduceScatter + GEMM using P2P communication */ - void atomic_gemm_overlap_rs( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &rs_output, cudaStream_t stream_main); + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, const TensorWrapper &D, const TensorWrapper &bias, + const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &rs_output, cudaStream_t stream_main); /* ** Split ReduceScatter + GEMM using P2P communication */ - void split_overlap_rs( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &rs_output, cudaStream_t stream_main); + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + const TensorWrapper &D, const TensorWrapper &bias, + const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + const TensorWrapper &rs_output, cudaStream_t stream_main); }; // CommOverlapP2PBase } // namespace transformer_engine diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 5d3f02033d..a90077f935 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -8,71 +8,70 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include - -#include -#include #include +#include +#include -#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType") \ - .value("kByte", transformer_engine::DType::kByte) \ - .value("kInt32", transformer_engine::DType::kInt32) \ - .value("kFloat32", transformer_engine::DType::kFloat32) \ - .value("kFloat16", transformer_engine::DType::kFloat16) \ - .value("kBFloat16", transformer_engine::DType::kBFloat16) \ - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ - pybind11::enum_(m, "NVTE_Bias_Type") \ - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ - pybind11::enum_(m, "NVTE_Mask_Type") \ - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ - .value("NVTE_PADDING_MASK", 
NVTE_Mask_Type::NVTE_PADDING_MASK) \ - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_QKV_Layout") \ - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_(m, "CommOverlapType") \ - .value("RS", transformer_engine::CommOverlapType::RS) \ - .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo") \ - .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ - .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ - .value("SPLIT_PIPELINED_AG_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ - .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ - .value("SPLIT_PIPELINED_RS_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ - .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ - .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - m.def("device_supports_multicast", &transformer_engine::device_supports_multicast, \ - py::call_guard()); \ - m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType") \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type") \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", 
NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type") \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Layout") \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "CommOverlapType") \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo") \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + m.def("device_supports_multicast", &transformer_engine::device_supports_multicast, \ + py::call_guard()); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); #endif diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 243306425f..175a7b0e90 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -11,7 +11,6 @@ #include #include #include -#include #include #include #include @@ -22,7 +21,6 @@ #include #include #include -#include #include #include #include @@ -39,6 +37,7 @@ #include #include +#include #include #include #include @@ -46,6 +45,7 @@ #include #include 
#include +#include #include #include "common/util/logging.h" diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 42b576389a..d74bd4274a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -533,13 +533,12 @@ class CommOverlapHelper : torch::CustomClassHolder { void ub_barrier(char *group); }; -class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { +class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { public: - CommOverlap( - const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, int num_splits = 3, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits = 3, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); void set_ubuf_scale_inv(torch::Tensor scale_inv) { assert(scale_inv.numel()); @@ -568,36 +567,39 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOv /* ** Split FPROP GEMM + ReduceScatter */ - void atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output); + void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output); /* ** Split FPROP GEMM + ReduceScatter */ - void split_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output); + void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, 
transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, at::Tensor rs_output); }; // CommOverlap class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { public: - CommOverlapP2P( - const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, transformer_engine::CommOverlapType comm_type, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, - bool aggregate = false); + CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + transformer_engine::CommOverlapType comm_type, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, + bool use_ce = true, bool aggregate = false); void set_ubuf_scale_inv(torch::Tensor scale_inv) { assert(scale_inv.numel()); @@ -617,13 +619,15 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm ** in each rank to be in the contiguous memory space after all ring exchange *phases. */ - void atomic_gemm_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy); + void atomic_gemm_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, at::Tensor B_copy); /* ** Split AllGather + GEMM using P2P communication @@ -632,35 +636,41 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm ** in each rank to be in the contiguous memory space after all ring exchange *phases. 
*/ - void split_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy); + void split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor B_copy); /* ** Split ReduceScatter + GEMM using P2P communication */ - void atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output); + void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, at::Tensor rs_output); /* ** Split ReduceScatter + GEMM using P2P communication */ - void split_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output); + void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor rs_output); }; // CommOverlapP2P >>>>>>> 7f2dcc5 (added TE/PyTorch wrappers for refactored comm+GEMM overlap code in TE/common) diff --git 
a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 9e08265e0b..3a6f3359d3 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ - #include "../extensions.h" +#include "../extensions.h" #define HALF_BYTES 2 #define UB_MAX_SM 32 @@ -14,50 +14,48 @@ using namespace std::placeholders; namespace te = transformer_engine; -#define MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inv, A_fp8_index, A_type, \ - B, B_scale_inv, B_fp8_index, B_type, \ - D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, \ - workspace) \ - A = A.contiguous(); \ - void *A_scale_inv_ptr = nullptr; \ - if (te::is_fp8_dtype(A_type)) { \ - assert(A_scale_inv.numel()); \ - A_scale_inv_ptr = A_scale_inv[A_fp8_index].data_ptr(); \ - } \ - auto A_ = makeTransformerEngineTensor( \ - A.data_ptr(), std::vector{(size_t)A.size(0), (size_t)A.size(1)}, A_type, nullptr, \ - nullptr, A_scale_inv_ptr); \ - B = B.contiguous(); \ - void *B_scale_inv_ptr = nullptr; \ - if (te::is_fp8_dtype(B_type)) { \ - assert(B_scale_inv.numel()); \ - B_scale_inv_ptr = B_scale_inv[B_fp8_index].data_ptr(); \ - } \ - auto B_ = makeTransformerEngineTensor( \ - B.data_ptr(), std::vector{(size_t)B.size(0), (size_t)B.size(1)}, B_type, nullptr, \ - nullptr, B_scale_inv_ptr); \ - void *D_amax_ptr = nullptr; \ - void *D_scale_ptr = nullptr; \ - if (te::is_fp8_dtype(D_type)) { \ - assert(D_amax.numel()); \ - D_amax_ptr = D_amax.data_ptr(); \ - assert(D_scale.numel()); \ - D_scale_ptr = D_scale.data_ptr(); \ - } \ - auto D_ = makeTransformerEngineTensor( \ - D.data_ptr(), std::vector{(size_t)D.size(0), (size_t)D.size(1)}, D_type, \ - D_amax_ptr, D_scale_ptr, nullptr); \ - auto bias_ = makeTransformerEngineTensor( \ - bias.data_ptr(), std::vector{(size_t)bias.size(0)}, bias_type); \ - const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \ - ? 
std::vector{static_cast(pre_gelu_out.size(0))} \ - : std::vector{static_cast(pre_gelu_out.size(0)), \ - static_cast(pre_gelu_out.size(1))}; \ - auto pre_gelu_out_ = makeTransformerEngineTensor( \ - pre_gelu_out.data_ptr(), gelu_shape, \ - GetTransformerEngineDType(pre_gelu_out.scalar_type())); \ - auto workspace_ = makeTransformerEngineTensor( \ - workspace.data_ptr(), std::vector{(size_t)workspace.size(0)}, te::DType::kByte); +#define MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inv, A_fp8_index, A_type, B, B_scale_inv, \ + B_fp8_index, B_type, D, D_amax, D_scale, D_type, bias, \ + bias_type, pre_gelu_out, workspace) \ + A = A.contiguous(); \ + void *A_scale_inv_ptr = nullptr; \ + if (te::is_fp8_dtype(A_type)) { \ + assert(A_scale_inv.numel()); \ + A_scale_inv_ptr = A_scale_inv[A_fp8_index].data_ptr(); \ + } \ + auto A_ = makeTransformerEngineTensor(A.data_ptr(), \ + std::vector{(size_t)A.size(0), (size_t)A.size(1)}, \ + A_type, nullptr, nullptr, A_scale_inv_ptr); \ + B = B.contiguous(); \ + void *B_scale_inv_ptr = nullptr; \ + if (te::is_fp8_dtype(B_type)) { \ + assert(B_scale_inv.numel()); \ + B_scale_inv_ptr = B_scale_inv[B_fp8_index].data_ptr(); \ + } \ + auto B_ = makeTransformerEngineTensor(B.data_ptr(), \ + std::vector{(size_t)B.size(0), (size_t)B.size(1)}, \ + B_type, nullptr, nullptr, B_scale_inv_ptr); \ + void *D_amax_ptr = nullptr; \ + void *D_scale_ptr = nullptr; \ + if (te::is_fp8_dtype(D_type)) { \ + assert(D_amax.numel()); \ + D_amax_ptr = D_amax.data_ptr(); \ + assert(D_scale.numel()); \ + D_scale_ptr = D_scale.data_ptr(); \ + } \ + auto D_ = makeTransformerEngineTensor(D.data_ptr(), \ + std::vector{(size_t)D.size(0), (size_t)D.size(1)}, \ + D_type, D_amax_ptr, D_scale_ptr, nullptr); \ + auto bias_ = makeTransformerEngineTensor(bias.data_ptr(), \ + std::vector{(size_t)bias.size(0)}, bias_type); \ + const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \ + ? 
std::vector{static_cast(pre_gelu_out.size(0))} \ + : std::vector{static_cast(pre_gelu_out.size(0)), \ + static_cast(pre_gelu_out.size(1))}; \ + auto pre_gelu_out_ = makeTransformerEngineTensor( \ + pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); \ + auto workspace_ = makeTransformerEngineTensor( \ + workspace.data_ptr(), std::vector{(size_t)workspace.size(0)}, te::DType::kByte); /*************************************************************************************************** * CommOverlapHelper @@ -72,98 +70,100 @@ CommOverlapHelper::CommOverlapHelper() { CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, std::optional intra_node_group_holder, std::optional inter_node_group_holder) { -myrank = world_group->getRank(); -numranks = world_group->getSize(); -pgs.insert({"world", world_group}); -c10d::ProcessGroup::BackendType backend = world_group->getBackendType(); -backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); - -if (intra_node_group_holder.has_value()) { - NVTE_CHECK(inter_node_group_holder.has_value(), - "Internal TE error: Inter-node group cannot be `None` when intra-node group exists!"); - - // Get local rank on node and number of local ranks - c10d::ProcessGroup *intra_node_group = inter_node_group_holder.value(); - NVTE_CHECK(intra_node_group->getBackendType() == backend, - "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", - "group!", world_group->getBackendName()); - mylocal = intra_node_group->getRank(); - numlocal = intra_node_group->getSize(); - pgs.insert({"intra", intra_node_group}); - - // Get node ID and number of nodes - c10d::ProcessGroup *inter_node_group = intra_node_group_holder.value(); - NVTE_CHECK(inter_node_group->getBackendType() == backend, - "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", - "group!", world_group->getBackendName()); - mynode = inter_node_group->getRank(); - numnodes = inter_node_group->getSize(); -} else { - // There is only one node so local rank/size is equal to global rank/size - mylocal = myrank; - numlocal = numranks; - pgs.insert({"intra", world_group}); - - mynode = 0; - numnodes = 1; -} + myrank = world_group->getRank(); + numranks = world_group->getSize(); + pgs.insert({"world", world_group}); + c10d::ProcessGroup::BackendType backend = world_group->getBackendType(); + backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); + + if (intra_node_group_holder.has_value()) { + NVTE_CHECK( + inter_node_group_holder.has_value(), + "Internal TE error: Inter-node group cannot be `None` when intra-node group exists!"); + + // Get local rank on node and number of local ranks + c10d::ProcessGroup *intra_node_group = inter_node_group_holder.value(); + NVTE_CHECK(intra_node_group->getBackendType() == backend, + "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", + "group!", world_group->getBackendName()); + mylocal = intra_node_group->getRank(); + numlocal = intra_node_group->getSize(); + pgs.insert({"intra", intra_node_group}); + + // Get node ID and number of nodes + c10d::ProcessGroup *inter_node_group = intra_node_group_holder.value(); + NVTE_CHECK(inter_node_group->getBackendType() == backend, + "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", + "group!", world_group->getBackendName()); + mynode = inter_node_group->getRank(); + numnodes = inter_node_group->getSize(); + } else { + // There is only one node 
so local rank/size is equal to global rank/size + mylocal = myrank; + numlocal = numranks; + pgs.insert({"intra", world_group}); -initialized = true; + mynode = 0; + numnodes = 1; + } + + initialized = true; } CommOverlapHelper::~CommOverlapHelper() { -for (auto &pg : pgs) pg.second = nullptr; -backend_is_nccl = false; -initialized = false; + for (auto &pg : pgs) pg.second = nullptr; + backend_is_nccl = false; + initialized = false; } void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, char *group) { -NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", - "with valid process groups!"); - -auto localtensor = - torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); -auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; -auto globaltensor = - torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); -auto globaltmp = (backend_is_nccl) ? globaltensor.cuda() : globaltensor; - -std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; -std::vector localchunk = {localtmp}; -auto work = pgs[group]->allgather(globalchunks, localchunk); -work->wait(); - -if (backend_is_nccl) { - globaltensor.copy_(globaltmp.cpu()); - globaltmp = torch::Tensor(); - localtmp = torch::Tensor(); -} + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); + + auto localtensor = + torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; + auto globaltensor = + torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto globaltmp = (backend_is_nccl) ? 
globaltensor.cuda() : globaltensor; + + std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; + std::vector localchunk = {localtmp}; + auto work = pgs[group]->allgather(globalchunks, localchunk); + work->wait(); + + if (backend_is_nccl) { + globaltensor.copy_(globaltmp.cpu()); + globaltmp = torch::Tensor(); + localtmp = torch::Tensor(); + } } void CommOverlapHelper::ub_barrier(char *group) { -NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", - "with valid process groups!"); -auto work = pgs[group]->barrier(); -work->wait(); + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); + auto work = pgs[group]->barrier(); + work->wait(); } /*************************************************************************************************** * CommOverlap **************************************************************************************************/ -CommOverlap::CommOverlap( - const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, int num_splits, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool atomic_gemm) - : te::CommOverlapBase( - buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, - helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, num_max_streams, - comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) {} +CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits, + int num_max_streams, int comm_cga_size, int num_comm_sm, + bool set_sm_margin, bool atomic_gemm) + : te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, + helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, + helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, + num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { +} /* ** Bulk GEMM + COMM @@ -177,16 +177,15 @@ std::vector CommOverlap::bulk_overlap( transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, te::CommOverlapType comm_type, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, - B, B_scale_inverse, B_fp8_tensor, B_type, - D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, - workspace) + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) auto rs_out_ = makeTransformerEngineTensor(rs_output); cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::bulk_overlap( - A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, - use_split_accumulator, comm_type, rs_out_, stream_main); + te::CommOverlapBase::bulk_overlap(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, + grad, accumulate, use_split_accumulator, comm_type, rs_out_, + stream_main); // Get the current userbuf offset char *ubuf_wt_ptr = 
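For reference, the helper above is only a thin bootstrap shim over torch.distributed: ub_allgather round-trips raw bytes through the chosen process group (staging the host buffers through CUDA tensors when the backend is NCCL), and ub_barrier simply waits on the group barrier. A minimal Python-side sketch of constructing such a helper follows; the extension import path and the group choices are illustrative assumptions, not part of this change.

    import torch
    import transformer_engine_torch as tex  # assumed import path for the TE pybind extension

    torch.distributed.init_process_group(backend="nccl")

    # Any available torch.distributed backend works for the byte-level bootstrap collectives;
    # with NCCL the helper stages the CPU buffers through CUDA tensors as implemented above.
    world_group = torch.distributed.new_group(backend="gloo")

    # With no intra-/inter-node groups given, the helper assumes a single node and reuses
    # the world group as the intra-node group.
    helper = tex.CommOverlapHelper(world_group)
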
reinterpret_cast(_ubuf.dptr()); @@ -195,12 +194,12 @@ std::vector CommOverlap::bulk_overlap( } // Generate output tensor from userbuf data pointer - int output_c_dim0 = (comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) - : _ubuf.size(0) / _tp_size; + int output_c_dim0 = + (comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim1 = _ubuf.size(1); - auto output_tensor = torch::from_blob( - ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, - torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); + auto output_tensor = + torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, + torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); return {D, output_tensor}; } // CommOverlap::bulk_overlap @@ -216,39 +215,39 @@ void CommOverlap::atomic_gemm_overlap_rs( transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, - B, B_scale_inverse, B_fp8_tensor, B_type, - D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, - workspace) + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) auto rs_out_ = makeTransformerEngineTensor(rs_output); cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::atomic_gemm_overlap_rs( - A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, - use_split_accumulator, gemm_overlap, rs_out_, stream_main); + te::CommOverlapBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + gemm_overlap, rs_out_, stream_main); } // CommOverlap::split_overlap_rs /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlap::split_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, - B, B_scale_inverse, B_fp8_tensor, B_type, - D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, - workspace) +void CommOverlap::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, 
workspace) auto rs_out_ = makeTransformerEngineTensor(rs_output); cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::split_overlap_rs( - A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, - use_split_accumulator, gemm_overlap, rs_out_, stream_main); + te::CommOverlapBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + gemm_overlap, rs_out_, stream_main); } // CommOverlap::split_overlap_rs /* @@ -273,9 +272,8 @@ void CommOverlap::copy_input_to_ubuf(torch::Tensor input, int comm_type) { at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)_stream_comm)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); } torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { @@ -285,30 +283,28 @@ torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { NVTE_ERROR("Invalid comm_type"); if (_comm_type == te::CommOverlapType::RS) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - int output_c_dim0 = (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) - : _ubuf.size(0) / _tp_size; + int output_c_dim0 = + (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim1 = _ubuf.size(1); - return torch::from_blob( - ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, - torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); + return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, + torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); } /*************************************************************************************************** * CommOverlapP2P **************************************************************************************************/ -CommOverlapP2P::CommOverlapP2P( - const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, te::CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, - bool aggregate) +CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + te::CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool atomic_gemm, bool use_ce, bool aggregate) : te::CommOverlapP2PBase( buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, - num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, - aggregate) {} + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, + comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {} /* ** Split AllGather + AtomicGEMM using P2P communication @@ -324,16 +320,15 @@ void CommOverlapP2P::atomic_gemm_overlap_ag( 
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, - B, B_scale_inverse, B_fp8_tensor, B_type, - D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, - workspace) + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) auto B_copy_ = makeTransformerEngineTensor(B_copy); cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::atomic_gemm_overlap_ag( - A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, - use_split_accumulator, B_copy_, stream_main); + te::CommOverlapP2PBase::atomic_gemm_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, + use_split_accumulator, B_copy_, stream_main); } // atomic_gemm_overlap_ag /* @@ -350,16 +345,15 @@ void CommOverlapP2P::split_overlap_ag( at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, - B, B_scale_inverse, B_fp8_tensor, B_type, - D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, - workspace) + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) auto B_copy_ = makeTransformerEngineTensor(B_copy); cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::split_overlap_ag( - A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, - use_split_accumulator, B_copy_, stream_main); + te::CommOverlapP2PBase::split_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + B_copy_, stream_main); } // split_overlap_ag /* @@ -372,16 +366,15 @@ void CommOverlapP2P::atomic_gemm_overlap_rs( at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, - B, B_scale_inverse, B_fp8_tensor, B_type, - D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, - workspace) + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) auto rs_out_ = makeTransformerEngineTensor(rs_output); cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::atomic_gemm_overlap_rs( - A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, - use_split_accumulator, rs_out_, stream_main); + te::CommOverlapP2PBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, 
accumulate, + use_split_accumulator, rs_out_, stream_main); } /* @@ -394,16 +387,15 @@ void CommOverlapP2P::split_overlap_rs( at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, - B, B_scale_inverse, B_fp8_tensor, B_type, - D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, - workspace) + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) auto rs_out_ = makeTransformerEngineTensor(rs_output); cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::split_overlap_rs( - A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, grad, accumulate, - use_split_accumulator, rs_out_, stream_main); + te::CommOverlapP2PBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + rs_out_, stream_main); } /* @@ -418,16 +410,16 @@ void CommOverlapP2P::copy_input_to_ubuf(torch::Tensor input, bool chunk) { NVTE_ERROR("input and ubuf size do not match!"); } NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)stream_main)); } else { if (input.numel() != (int64_t)_ubuf.numel() || input.element_size() != (int64_t)_ubuf.element_size()) { NVTE_ERROR("input and ubuf size do not match!"); } NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)stream_main)); } } @@ -438,10 +430,9 @@ torch::Tensor CommOverlapP2P::get_ubuf_output(int comm_type) { NVTE_ERROR("Invalid comm_type"); if (_comm_type == te::CommOverlapType::RS) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); - int output_c_dim0 = (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) - : _ubuf.size(0) / _tp_size; + int output_c_dim0 = + (_comm_type == te::CommOverlapType::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim1 = _ubuf.size(1); - return torch::from_blob( - ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, - torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); + return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, + torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 5e7d173d19..39679ed669 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -7,9 +7,8 @@ #include #include -#include "common/util/pybind_helper.h" - #include "../extensions.h" +#include "common/util/pybind_helper.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) @@ -253,20 +252,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def(py::init<>(), py::call_guard()) .def(py::init, std::optional>(), - py::call_guard(), - py::arg("world_group"), py::arg("intra_node_group") = py::none(), - py::arg("inter_node_group") = py::none()); + py::call_guard(), py::arg("world_group"), + py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); py::class_(m, "CommOverlap") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, int, int, bool, bool>(), - py::call_guard(), - py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), + py::call_guard(), py::arg("buffer_shape"), + py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false) - .def("bulk_overlap", &CommOverlap::bulk_overlap, - py::call_guard()) + .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) .def("split_overlap_rs", &CommOverlap::split_overlap_rs, py::call_guard()) .def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs, @@ -277,22 +274,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()) .def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv, py::call_guard()) - .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, - py::call_guard()) - .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, - py::call_guard()) - .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, - py::call_guard()); + .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard()) + .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard()) + .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard()); py::class_(m, "CommOverlapP2P") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, transformer_engine::CommOverlapType, int, int, int, bool, bool, bool, bool>(), - py::call_guard(), - py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), - py::arg("comm_type"), py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, - py::arg("comm_cga_size") = 1, py::arg("num_comm_sm") = 1, - py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, - py::arg("use_ce") = true, py::arg("aggregate") = false) + py::call_guard(), py::arg("buffer_shape"), + py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), + py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, + py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, + 
py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag, py::call_guard()) .def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs, @@ -307,8 +300,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()) .def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv, py::call_guard()) - .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, - py::call_guard()) + .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard()) .def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm, py::call_guard()) .def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c7e60d7b12..cf68d5a47f 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -206,8 +206,6 @@ def initialize_ub( flush=True, ) - - # Increase the workspace by the number of maximum concurrent streams global _cublas_workspace _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS) From aaca9d847f393a91b84ca3dc2b82edea679eda6b Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 27 Aug 2024 21:06:43 +0000 Subject: [PATCH 09/34] fixing linting errors Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 4 +-- .../csrc/extensions/comm_gemm_overlap.cpp | 25 ++++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index ec500610e4..c918fabefa 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -94,8 +94,8 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes)); NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes)); NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2)); - _counter = - TensorWrapper(counter_ptr, std::vector{(size_t)_num_splits * 2}, DType::kInt32); + _counter = TensorWrapper( + counter_ptr, std::vector{static_cast(_num_splits * 2)}, DType::kInt32); } // CUDA event creation cudaEventCreateWithFlags(&_start_compute, 0); diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 3a6f3359d3..6290f935b3 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -23,18 +23,18 @@ namespace te = transformer_engine; assert(A_scale_inv.numel()); \ A_scale_inv_ptr = A_scale_inv[A_fp8_index].data_ptr(); \ } \ - auto A_ = makeTransformerEngineTensor(A.data_ptr(), \ - std::vector{(size_t)A.size(0), (size_t)A.size(1)}, \ - A_type, nullptr, nullptr, A_scale_inv_ptr); \ + auto A_ = makeTransformerEngineTensor( \ + A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, \ + nullptr, nullptr, A_scale_inv_ptr); \ B = B.contiguous(); \ void *B_scale_inv_ptr = nullptr; \ if (te::is_fp8_dtype(B_type)) { \ assert(B_scale_inv.numel()); \ B_scale_inv_ptr = B_scale_inv[B_fp8_index].data_ptr(); \ } \ - auto B_ = makeTransformerEngineTensor(B.data_ptr(), \ - std::vector{(size_t)B.size(0), (size_t)B.size(1)}, \ - B_type, nullptr, nullptr, B_scale_inv_ptr); \ + auto B_ = makeTransformerEngineTensor( \ + B.data_ptr(), 
{static_cast(B.size(0)), static_cast(B.size(1))}, B_type, \ + nullptr, nullptr, B_scale_inv_ptr); \ void *D_amax_ptr = nullptr; \ void *D_scale_ptr = nullptr; \ if (te::is_fp8_dtype(D_type)) { \ @@ -43,11 +43,11 @@ namespace te = transformer_engine; assert(D_scale.numel()); \ D_scale_ptr = D_scale.data_ptr(); \ } \ - auto D_ = makeTransformerEngineTensor(D.data_ptr(), \ - std::vector{(size_t)D.size(0), (size_t)D.size(1)}, \ - D_type, D_amax_ptr, D_scale_ptr, nullptr); \ - auto bias_ = makeTransformerEngineTensor(bias.data_ptr(), \ - std::vector{(size_t)bias.size(0)}, bias_type); \ + auto D_ = makeTransformerEngineTensor( \ + D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, \ + D_amax_ptr, D_scale_ptr, nullptr); \ + auto bias_ = makeTransformerEngineTensor( \ + bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); \ const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \ ? std::vector{static_cast(pre_gelu_out.size(0))} \ : std::vector{static_cast(pre_gelu_out.size(0)), \ @@ -55,7 +55,8 @@ namespace te = transformer_engine; auto pre_gelu_out_ = makeTransformerEngineTensor( \ pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); \ auto workspace_ = makeTransformerEngineTensor( \ - workspace.data_ptr(), std::vector{(size_t)workspace.size(0)}, te::DType::kByte); + workspace.data_ptr(), std::vector{static_cast(workspace.size(0))}, \ + te::DType::kByte); /*************************************************************************************************** * CommOverlapHelper From 45acb5e4e26c8ad90c28447f957a3be34f88ef19 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 21:07:07 +0000 Subject: [PATCH 10/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index c918fabefa..72a24a2add 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -94,8 +94,8 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes)); NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes)); NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2)); - _counter = TensorWrapper( - counter_ptr, std::vector{static_cast(_num_splits * 2)}, DType::kInt32); + _counter = TensorWrapper(counter_ptr, std::vector{static_cast(_num_splits * 2)}, + DType::kInt32); } // CUDA event creation cudaEventCreateWithFlags(&_start_compute, 0); From cbd22f2c015f35345d60f1ef8ba08830a72a4d4f Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 27 Aug 2024 22:23:39 +0000 Subject: [PATCH 11/34] added documentation for te.initialize_ub Signed-off-by: Alp Dener --- build_tools/pytorch.py | 3 + docs/api/pytorch.rst | 4 + .../userbuffers/userbuffers-host.cpp | 8 +- transformer_engine/pytorch/module/base.py | 147 ++++++++++++++++-- 4 files changed, 148 insertions(+), 14 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index ba1827c6bb..0f9bc7c07d 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -56,6 +56,9 @@ def setup_pytorch_extension( "--expt-extended-lambda", 
"--use_fast_math", ] + if bool(os.getenv("NVTE_UB_WITH_MPI", "0")): + cxx_flags.append("-DNVTE_UB_WITH_MPI") + nvcc_flags.append("-DNVTE_UB_WITH_MPI") cuda_architectures = cuda_archs() diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index b097f14475..32b60f5bfb 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -51,3 +51,7 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_permute .. autoapifunction:: transformer_engine.pytorch.moe_unpermute + +.. autoapifunction:: transformer_engine.initialize_ub + +.. autoapifunction:: transformer_engine.destroy_ub diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 9e4bc7d3dc..f97f82b527 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -299,11 +299,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, #define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF) // peer pointers + op flags + comm buffer - NVTE_CHECK_CUDA( - cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet - NVTE_CHECK_CUDA(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE)); + // NVTE_CHECK_CUDA( + // cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet + // NVTE_CHECK_CUDA(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE)); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, false); + register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true); NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int))); NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cf68d5a47f..6426e9714a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -84,10 +84,139 @@ def initialize_ub( tp_size: int, use_fp8: bool = False, dtype: torch.dtype = torch.bfloat16, - ub_cfgs: Optional[dict] = None, + ub_cfgs: Optional[dict] = "nccl", bootstrap_backend: Union[str, torch.distributed.Backend] = None, ) -> None: - """Initialize communicators for TP comm overlap using userbuffers.""" + r""" + Initialize the Userbuffers communicator for overlapping tensor-parallel communications with + GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules. + + Parameters + ---------- + shape : list + shape of the communication buffer, typically set to be the same as the global shape of + the input tensor to a te.TransformerLayer forward pass, with the sequence and batch + dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)` + tp_size : int + number of GPUs in the tensor-parallel process group + use_fp8 : bool = False + allocate the communication buffer for FP8 GEMM inputs/outputs + dtype : torch.dtype = torch.bfloat16 + non-FP8 data type of the communication buffer when `use_fp8 = False` + ub_cfgs: dict = None + Configuration dictionary containing a Userbuffers options for each GEMM layer in a + te.TransformerLayer. 
Layers that are not configured by the user fall back on the + default options below: + { + "qkv_fprop": { + "method": "ring_exchange", + "is_reduce_scatter": False, + "num_sm": 1, + "cga_size": 1, + "set_sm_margin": False, + "num_splits": tp_size, + "aggregate": False, + "atomic_gemm": False, + "use_ce": True, + "fp8_buf": True, + }, + "qkv_dgrad": { + "method": "bulk", + "is_reduce_scatter": False, + "num_sm": 16, + "cga_size": 2, + "set_sm_margin": True, + "fp8_buf": True, + }, + "qkv_wgrad": { + "method": "bulk", + "is_reduce_scatter": True, + "num_sm": 16, + "cga_size": 2, + "set_sm_margin": True, + "fp8_buf": True, + }, + "proj_fprop": { + "method": "pipeline", + "is_reduce_scatter": True, + "num_sm": 16, + "cga_size": 2, + "set_sm_margin": True, + "num_splits": 4, + "atomic_gemm": False, + "fp8_buf": False, + }, + "proj_dgrad": { + "method": "ring_exchange", + "is_reduce_scatter": False, + "num_sm": 1, + "cga_size": 1, + "set_sm_margin": False, + "num_splits": tp_size, + "aggregate": False, + "atomic_gemm": False, + "use_ce": True, + "fp8_buf": True, + }, + "fc1_fprop": { + "method": "ring_exchange", + "is_reduce_scatter": False, + "num_sm": 1, + "cga_size": 1, + "set_sm_margin": False, + "num_splits": tp_size, + "aggregate": False, + "atomic_gemm": False, + "use_ce": True, + "fp8_buf": True, + }, + "fc1_dgrad": { + "method": "bulk", + "is_reduce_scatter": False, + "num_sm": 16, + "cga_size": 2, + "set_sm_margin": True, + "fp8_buf": True, + }, + "fc1_wgrad": { + "method": "bulk", + "is_reduce_scatter": True, + "num_sm": 16, + "cga_size": 2, + "set_sm_margin": True, + "fp8_buf": False, + }, + "fc2_fprop": { + "method": "pipeline", + "is_reduce_scatter": True, + "num_sm": 16, + "cga_size": 2, + "set_sm_margin": True, + "num_splits": 4, + "atomic_gemm": False, + "fp8_buf": False, + }, + "fc2_dgrad": { + "method": "ring_exchange", + "is_reduce_scatter": False, + "num_sm": 1, + "cga_size": 1, + "set_sm_margin": False, + "num_splits": tp_size, + "aggregate": False, + "atomic_gemm": False, + "use_ce": True, + "fp8_buf": True, + }, + } + bootstrap_backend : str = "nccl" + `torch.distributed` communication backend for the all-gather, broadcast and + barrier collectives during Userbuffers initialization. Not all backends are + valid for every cluster configuration and distributed launch method even if + they are available in PyTorch. Setting `NVTE_UB_WITH_MPI=1` when building + TE overrides this option and always initializes Userbuffers with direct MPI + calls in C++, which requires `MPI_HOME` to be set at compile time. + """ if not tex.device_supports_multicast(): assert bool(os.getenv("UB_SKIPMC", "0")), ( "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " @@ -110,14 +239,12 @@ def initialize_ub( assert ( torch.distributed.is_initialized() ), "torch.distributed must be initialized before Userbuffers" - if bootstrap_backend is None: - bootstrap_backend = "nccl" - if torch.distributed.is_mpi_available(): - bootstrap_backend = "mpi" - elif torch.distributed.is_gloo_available(): - bootstrap_backend = "gloo" - else: - assert bootstrap_backend in ["gloo", "mpi", "nccl"] + assert ( + torch.distributed.is_backend_available(bootstrap_backend) + ), ( + f"PyTorch must be compiled with '{bootstrap_backend}' support in order to bootstrap " + + f"Userbuffers with '{bootstrap_backend}' collectives." 
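Because unspecified layers fall back on these defaults, a user-supplied ub_cfgs only needs to name the GEMMs being tuned. For example, a partial override that reduces the split count for the projection and fc2 forward GEMMs could look like the following; the particular values are illustrative only.

    ub_cfgs = {
        "proj_fprop": {"method": "pipeline", "num_splits": 2, "num_sm": 24},
        "fc2_fprop": {"method": "pipeline", "num_splits": 2, "num_sm": 24},
        # all other layers ("qkv_fprop", "fc1_dgrad", ...) keep the defaults listed above
    }
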
+ ) world_group = torch.distributed.new_group(backend=bootstrap_backend) world_rank = torch.distributed.get_rank(world_group) From 557abbc054452187a5c339b7026d0604f0bd2125 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 22:24:12 +0000 Subject: [PATCH 12/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6426e9714a..a8d9925145 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -239,9 +239,7 @@ def initialize_ub( assert ( torch.distributed.is_initialized() ), "torch.distributed must be initialized before Userbuffers" - assert ( - torch.distributed.is_backend_available(bootstrap_backend) - ), ( + assert torch.distributed.is_backend_available(bootstrap_backend), ( f"PyTorch must be compiled with '{bootstrap_backend}' support in order to bootstrap " + f"Userbuffers with '{bootstrap_backend}' collectives." ) From 59c1ced2912a86f5566bc67a7436c82e50a3d04c Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 27 Aug 2024 22:56:33 +0000 Subject: [PATCH 13/34] fixed compile errors when building with NVTE_UB_WITH_MPI=1 Signed-off-by: Alp Dener --- build_tools/pytorch.py | 13 +- .../distributed/run_gemm_with_overlap.py | 1 + transformer_engine/common/CMakeLists.txt | 1 + .../userbuffers/userbuffers-host.cpp | 28 +-- .../userbuffers/userbuffers.h | 10 +- transformer_engine/pytorch/csrc/extensions.h | 11 +- .../csrc/extensions/comm_gemm_overlap.cpp | 168 ++++++++++-------- 7 files changed, 126 insertions(+), 106 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 0f9bc7c07d..f438d203d6 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -11,7 +11,6 @@ from .utils import ( all_files_in_dir, cuda_archs, - cuda_path, cuda_version, ) @@ -56,9 +55,6 @@ def setup_pytorch_extension( "--expt-extended-lambda", "--use_fast_math", ] - if bool(os.getenv("NVTE_UB_WITH_MPI", "0")): - cxx_flags.append("-DNVTE_UB_WITH_MPI") - nvcc_flags.append("-DNVTE_UB_WITH_MPI") cuda_architectures = cuda_archs() @@ -85,6 +81,15 @@ def setup_pytorch_extension( continue # Already handled nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"]) + if os.getenv("NVTE_UB_WITH_MPI") is not None: + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" 
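Putting the pieces together, a typical single-node invocation of the API documented here is sketched below. The module path, buffer shape and chosen backend are illustrative; builds with NVTE_UB_WITH_MPI=1 bootstrap over MPI instead and ignore bootstrap_backend.

    import torch
    import transformer_engine.pytorch as te  # initialize_ub/destroy_ub live in te.module.base

    torch.distributed.init_process_group(backend="nccl")
    tp_size = torch.distributed.get_world_size()

    seq_len, batch_size, hidden_size = 2048, 2, 4096  # placeholder model dimensions
    te.module.base.initialize_ub(
        [seq_len * batch_size, hidden_size],  # sequence and batch dims collapsed together
        tp_size,
        use_fp8=False,
        dtype=torch.bfloat16,
        ub_cfgs=None,               # or a partial override like the one sketched earlier
        bootstrap_backend="gloo",   # must be available in this PyTorch build
    )

    # ... build tensor-parallel TE modules and run training ...

    te.module.base.destroy_ub()
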
+ mpi_path = Path(os.getenv("MPI_HOME")) + include_dirs.append(mpi_path / "include") + cxx_flags.append("-DNVTE_UB_WITH_MPI") + nvcc_flags.append("-DNVTE_UB_WITH_MPI") + # Construct PyTorch CUDA extension sources = [str(path) for path in sources] include_dirs = [str(path) for path in include_dirs] diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index fcf003e380..b00b8cc042 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -22,6 +22,7 @@ from transformer_engine.common.recipe import Format from transformer_engine.pytorch.fp8 import _default_sf_compute +warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 4e01dbd710..ca23008edd 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -103,6 +103,7 @@ if (NVTE_UB_WITH_MPI) find_package(MPI REQUIRED) target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) + target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) endif() # Hack to enable dynamic loading in cuDNN frontend diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index f97f82b527..c6f11005a4 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -44,31 +44,19 @@ static MPI_Comm EXT_COMM_INTER; } while (false) void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, - ExtComm group) { - // UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE, - // globaldata, globalbytes, MPI_BYTE, - // static_cast(group))); - MPI_Comm comm = static_cast(group); + ExtComm comm) { int numranks; UB_MPI_CHECK(MPI_Comm_size(comm, &numranks)); assert(globalbytes == numranks * localbytes); - - int myrank; - UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank)); - char *globaltarget = reinterpret_cast(globaldata) + (myrank * localbytes); - memcpy(globaltarget, localdata, localbytes); - - for (int n = 0; n < numranks; n++) { - globaltarget = reinterpret_cast(globaldata) + (n * localbytes); - UB_MPI_CHECK(MPI_Bcast(globaltarget, localbytes, MPI_BYTE, n, comm)); - } + UB_MPI_CHECK( + MPI_Allgather(localdata, localbytes, MPI_BYTE, globaldata, localbytes, MPI_BYTE, comm)); } -void ub_mpi_barrier(ExtComm group) { UB_MPI_CHECK(MPI_Barrier(static_cast(group))); } +void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); } #else -static char EXT_COMM_WORLD[] = "world"; -static char EXT_COMM_INTRA[] = "intra"; -static char EXT_COMM_INTER[] = "inter"; +#define EXT_COMM_WORLD "world" +#define EXT_COMM_INTRA "intra" +#define EXT_COMM_INTER "inter" #endif #define MULTICAST_GB_TOTAL 512 @@ -426,7 +414,7 @@ int create_communicator_mpi(communicator **comm) { void destroy_communicator(communicator *comm) { for (int hndl = 0; hndl < comm->free_region; hndl++) { - if (hndl > 0 && comm->use_mc && comm->mem_dealloc[hndl]) { + if (comm->use_mc && comm->mem_dealloc[hndl]) { for (int rank = 0; rank < comm->nvsize; rank++) { if (rank == comm->nvrank) { 
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 548f452ac3..840cb468d2 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -19,9 +19,9 @@ #ifdef NVTE_UB_WITH_MPI #include -typedef MPI_Comm ExtComm; +#define ExtComm MPI_Comm #else -typedef char *ExtComm; +#define ExtComm const char * #endif #define ExtAllgatherOp std::function @@ -148,9 +148,9 @@ struct communicator { ExtAllgatherOp _allgather; ExtBarrierOp _barrier; - ExtComm comm_world, - comm_inter, // reduction group communicator (subset of the nodes) along GPU rail - comm_intra; // full intranode (all ndev GPUS) + ExtComm comm_world; + ExtComm comm_inter; // reduction group communicator (subset of the nodes) along GPU rail + ExtComm comm_intra; // full intranode (all ndev GPUS) #ifdef NVTE_UB_WITH_MPI MPI_Request mpihndl[NVTE_MAX_SHARP]; #endif diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d74bd4274a..4478ddbf81 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -517,7 +517,12 @@ class CommOverlapHelper : torch::CustomClassHolder { std::map pgs; public: - int myrank, numranks, mylocal, numlocal, mynode, numnodes; + int myrank = -1; + int numranks = -1; + int mylocal = -1; + int numlocal = -1; + int mynode = -1; + int numnodes = -1; CommOverlapHelper(); @@ -528,9 +533,9 @@ class CommOverlapHelper : torch::CustomClassHolder { ~CommOverlapHelper(); void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, - char *group); + ExtComm comm); - void ub_barrier(char *group); + void ub_barrier(ExtComm comm); }; class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 6290f935b3..b14c1f128e 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -71,6 +71,7 @@ CommOverlapHelper::CommOverlapHelper() { CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, std::optional intra_node_group_holder, std::optional inter_node_group_holder) { +#ifndef NVTE_UB_WITH_MPI myrank = world_group->getRank(); numranks = world_group->getSize(); pgs.insert({"world", world_group}); @@ -78,10 +79,6 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); if (intra_node_group_holder.has_value()) { - NVTE_CHECK( - inter_node_group_holder.has_value(), - "Internal TE error: Inter-node group cannot be `None` when intra-node group exists!"); - // Get local rank on node and number of local ranks c10d::ProcessGroup *intra_node_group = inter_node_group_holder.value(); NVTE_CHECK(intra_node_group->getBackendType() == backend, @@ -91,15 +88,32 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, numlocal = intra_node_group->getSize(); pgs.insert({"intra", intra_node_group}); - // Get node ID and number of nodes - c10d::ProcessGroup *inter_node_group = intra_node_group_holder.value(); - NVTE_CHECK(inter_node_group->getBackendType() == 
backend, - "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", - "group!", world_group->getBackendName()); - mynode = inter_node_group->getRank(); - numnodes = inter_node_group->getSize(); + if (numlocal == numranks) { + // Intra-node group is same as the world group so there can only be 1 node + NVTE_CHECK( + mylocal == myrank, + "Internal TE error: Local rank must be equal to global rank when intra-node group size ", + "is equal to the world group size!"); + mynode = 0; + numnodes = 1; + } else { + // Intra-node group is different than the world group so there must be multiple nodes + NVTE_CHECK( + inter_node_group_holder.has_value(), + "Internal TE error: Inter-node group cannot be `None` when intra-node group is not ", + "identical to the world_group!"); + + // Get node ID and number of nodes + c10d::ProcessGroup *inter_node_group = intra_node_group_holder.value(); + NVTE_CHECK( + inter_node_group->getBackendType() == backend, + "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", + "group!", world_group->getBackendName()); + mynode = inter_node_group->getRank(); + numnodes = inter_node_group->getSize(); + } } else { - // There is only one node so local rank/size is equal to global rank/size + // Intra-node group is not set so we assume there is only 1 node mylocal = myrank; numlocal = numranks; pgs.insert({"intra", world_group}); @@ -109,16 +123,23 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, } initialized = true; +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled with NVTE_UB_WITH_MPI=1!"); +#endif } CommOverlapHelper::~CommOverlapHelper() { +#ifndef NVTE_UB_WITH_MPI for (auto &pg : pgs) pg.second = nullptr; backend_is_nccl = false; initialized = false; +#endif } void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void *localdata, - size_t localbytes, char *group) { + size_t localbytes, ExtComm group) { +#ifndef NVTE_UB_WITH_MPI NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", "with valid process groups!"); @@ -141,29 +162,38 @@ void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void globaltmp = torch::Tensor(); localtmp = torch::Tensor(); } +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_allgather is a no-op when TE is compiled ", + "with NVTE_UB_WITH_MPI=1!"); +#endif } -void CommOverlapHelper::ub_barrier(char *group) { +void CommOverlapHelper::ub_barrier(ExtComm group) { +#ifndef NVTE_UB_WITH_MPI NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", "with valid process groups!"); auto work = pgs[group]->barrier(); work->wait(); +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_barrier is a no-op when TE is compiled ", + "with NVTE_UB_WITH_MPI=1!"); +#endif } /*************************************************************************************************** * CommOverlap **************************************************************************************************/ -CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, int num_splits, - int num_max_streams, int comm_cga_size, int num_comm_sm, - bool set_sm_margin, bool atomic_gemm) - : te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, - helper->numranks, helper->mylocal, 
helper->numlocal, helper->mynode, - helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, - num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { +CommOverlap::CommOverlap( + const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, + int tp_size, int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, + bool set_sm_margin, bool atomic_gemm) + : te::CommOverlapBase( + buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, + helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, num_max_streams, + comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { } /* @@ -171,11 +201,10 @@ CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType ** This function assumes the communication input is pre-copied to _ubuf */ std::vector CommOverlap::bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, te::CommOverlapType comm_type, at::Tensor rs_output) { MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, @@ -209,11 +238,10 @@ std::vector CommOverlap::bulk_overlap( ** Split FPROP GEMM + ReduceScatter */ void CommOverlap::atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, at::Tensor rs_output) { MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, @@ -230,16 +258,13 @@ void CommOverlap::atomic_gemm_overlap_rs( /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlap::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType 
A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { +void CommOverlap::split_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output) { MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, workspace) @@ -295,17 +320,16 @@ torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { * CommOverlapP2P **************************************************************************************************/ -CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, - te::CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool atomic_gemm, bool use_ce, bool aggregate) +CommOverlapP2P::CommOverlapP2P( + const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, + int tp_size, te::CommOverlapType comm_type, int num_max_streams, int comm_cga_size, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, bool aggregate) : te::CommOverlapP2PBase( - buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, - helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, - comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {} + buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, + helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, + comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {} /* ** Split AllGather + AtomicGEMM using P2P communication @@ -315,11 +339,10 @@ CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::Scal *phases. 
*/ void CommOverlapP2P::atomic_gemm_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, @@ -340,11 +363,10 @@ void CommOverlapP2P::atomic_gemm_overlap_ag( *phases. */ void CommOverlapP2P::split_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, @@ -361,11 +383,10 @@ void CommOverlapP2P::split_overlap_ag( ** Split ReduceScatter + GEMM using P2P communication */ void CommOverlapP2P::atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, @@ -382,11 +403,10 @@ void CommOverlapP2P::atomic_gemm_overlap_rs( ** Split 
ReduceScatter + GEMM using P2P communication */ void CommOverlapP2P::split_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, From e85062bac45200018235465938e319f97cb14792 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Aug 2024 15:32:35 +0000 Subject: [PATCH 14/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../csrc/extensions/comm_gemm_overlap.cpp | 55 ++++++++++--------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index b14c1f128e..f05b7a1d75 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -184,16 +184,16 @@ void CommOverlapHelper::ub_barrier(ExtComm group) { * CommOverlap **************************************************************************************************/ -CommOverlap::CommOverlap( - const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, - int tp_size, int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, - bool set_sm_margin, bool atomic_gemm) - : te::CommOverlapBase( - buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, - helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, num_max_streams, - comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { +CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits, + int num_max_streams, int comm_cga_size, int num_comm_sm, + bool set_sm_margin, bool atomic_gemm) + : te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, + helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, + helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, + num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { } /* @@ -258,13 +258,15 @@ void CommOverlap::atomic_gemm_overlap_rs( /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlap::split_overlap_rs( 
- at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { +void CommOverlap::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + te::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + te::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, + te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output) { MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, pre_gelu_out, workspace) @@ -320,16 +322,17 @@ torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { * CommOverlapP2P **************************************************************************************************/ -CommOverlapP2P::CommOverlapP2P( - const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, - int tp_size, te::CommOverlapType comm_type, int num_max_streams, int comm_cga_size, - int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, bool aggregate) +CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + te::CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool atomic_gemm, bool use_ce, bool aggregate) : te::CommOverlapP2PBase( - buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, - helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, - comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {} + buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, + helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, + comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {} /* ** Split AllGather + AtomicGEMM using P2P communication From 502b217fbf41d8cbf22dda3a687d3c5c6fc19d54 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 28 Aug 2024 19:50:50 +0000 Subject: [PATCH 15/34] fixed default bootstrap backend Signed-off-by: Alp Dener --- transformer_engine/pytorch/module/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a8d9925145..e754ca617a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -84,8 +84,8 @@ def initialize_ub( tp_size: int, use_fp8: bool = False, dtype: torch.dtype = torch.bfloat16, - ub_cfgs: 
Optional[dict] = "nccl", - bootstrap_backend: Union[str, torch.distributed.Backend] = None, + ub_cfgs: Optional[dict] = None, + bootstrap_backend: Union[str, torch.distributed.Backend] = "nccl", ) -> None: r""" Initialize the Userbuffers communicator for overlapping tensor-parallel communications with @@ -309,7 +309,7 @@ def initialize_ub( ranks_per_node_tensor = torch.tensor(ranks_per_node_list, dtype=int) ranks_across_nodes_list = ranks_per_node_tensor.transpose(0, 1).tolist() - inter_node_group, _ = torch.distirbuted.new_subgroups_by_enumeration( + inter_node_group, _ = torch.distributed.new_subgroups_by_enumeration( ranks_across_nodes_list, backend=bootstrap_backend ) From b3cdf29f53d4c440736b441eb7d89a115efbb0a3 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 28 Aug 2024 16:45:30 -0700 Subject: [PATCH 16/34] switched default bootstrap backend priority to MPI > Gloo > NCCL Signed-off-by: Alp Dener --- transformer_engine/pytorch/module/base.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e754ca617a..c2c520bd1d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -85,7 +85,7 @@ def initialize_ub( use_fp8: bool = False, dtype: torch.dtype = torch.bfloat16, ub_cfgs: Optional[dict] = None, - bootstrap_backend: Union[str, torch.distributed.Backend] = "nccl", + bootstrap_backend: Union[str, torch.distributed.Backend] = None, ) -> None: r""" Initialize the Userbuffers communicator for overlapping tensor-parallel communications with @@ -239,10 +239,18 @@ def initialize_ub( assert ( torch.distributed.is_initialized() ), "torch.distributed must be initialized before Userbuffers" - assert torch.distributed.is_backend_available(bootstrap_backend), ( - f"PyTorch must be compiled with '{bootstrap_backend}' support in order to bootstrap " - + f"Userbuffers with '{bootstrap_backend}' collectives." - ) + if bootstrap_backend is None: + bootstrap_backend = "nccl" + if torch.distributed.is_mpi_available(): + bootstrap_backend = "mpi" + elif torch.distributed.is_gloo_available(): + bootstrap_backend = "gloo" + else: + assert bootstrap_backend in ["mpi", "gloo", "nccl"], "Invalid bootstrap backend!" + assert torch.distributed.is_backend_available(bootstrap_backend), ( + f"PyTorch must be compiled with '{bootstrap_backend}' support in order to bootstrap " + + f"Userbuffers with '{bootstrap_backend}' collectives." + ) world_group = torch.distributed.new_group(backend=bootstrap_backend) world_rank = torch.distributed.get_rank(world_group) From e4679edea233db20631763a193e5d833a42bf6fc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Aug 2024 23:45:58 +0000 Subject: [PATCH 17/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c2c520bd1d..b80d87a612 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -248,8 +248,8 @@ def initialize_ub( else: assert bootstrap_backend in ["mpi", "gloo", "nccl"], "Invalid bootstrap backend!" 
assert torch.distributed.is_backend_available(bootstrap_backend), ( - f"PyTorch must be compiled with '{bootstrap_backend}' support in order to bootstrap " - + f"Userbuffers with '{bootstrap_backend}' collectives." + f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " + f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." ) world_group = torch.distributed.new_group(backend=bootstrap_backend) From f675d136e84aab8879dc3ebd62ed2b6a621bfd6d Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 28 Aug 2024 16:51:21 -0700 Subject: [PATCH 18/34] updated bootstrap backend documentation Signed-off-by: Alp Dener --- transformer_engine/pytorch/module/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b80d87a612..f22cecd2a8 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -209,13 +209,15 @@ def initialize_ub( "fp8_buf": True, }, } - bootstrap_backend : str = "nccl" + bootstrap_backend : str = None `torch.distributed` communication backend for the all-gather, broadcast and barrier collectives during Userbuffers initialization. Not all backends are valid for every cluster configuration and distributed launch method even if - they are available in PyTorch. Setting `NVTE_UB_WITH_MPI=1` when building - TE overrides this option and always initializes Userbuffers with direct MPI - calls in C++, which requires `MPI_HOME` to be set at compile time. + they are available in PyTorch. When left unset, the initialization prefers + to use the MPI backend, falling back first on Gloo and then NCCL if MPI is + not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this + option and always initializes Userbuffers with direct MPI calls in C++, + which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time. """ if not tex.device_supports_multicast(): assert bool(os.getenv("UB_SKIPMC", "0")), ( From 3517bebb3171bc8938fb83d7ce795f0b888a24ea Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 29 Aug 2024 15:46:54 +0000 Subject: [PATCH 19/34] close UB bootstrap socket to avoid interfering with CUDA Multicast shareable file handle send/recv Signed-off-by: Alp Dener --- transformer_engine/pytorch/module/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index f22cecd2a8..269547f574 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -285,9 +285,8 @@ def initialize_ub( else: ifname_warning = ( f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will" - " attempt to " - + "detect ranks on the same node by matching 'socket.gethostname()', which is " - + "known to fail on virtual clusters like Kubernetes. If Userbuffers " + " attempt to detect ranks on the same node by matching 'socket.gethostname()', " + + "which is known to fail on virtual clusters like Kubernetes. If Userbuffers " + "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in " + "your environment to the correct network interface." 
) From 776ad27f2a853a052f7b8e0cf7c96d55215cd315 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 29 Aug 2024 22:12:25 +0000 Subject: [PATCH 20/34] added torch::Tensor wrappers for communication buffer and atomic counters so PyTorch can factor externally allocated memory into its garbage collection threshold Signed-off-by: Alp Dener --- transformer_engine/pytorch/csrc/extensions.h | 7 +++++ .../csrc/extensions/comm_gemm_overlap.cpp | 31 ++++++++++++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4478ddbf81..5d66004caf 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -539,6 +539,10 @@ class CommOverlapHelper : torch::CustomClassHolder { }; class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { + private: + torch::Tensor _ubuf_torch; + torch::Tensor _ubuf_counter; + public: CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits = 3, @@ -598,6 +602,9 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve }; // CommOverlap class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { + private: + torch::Tensor _ubuf_torch; + torch::Tensor _ubuf_counter; public: CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index f05b7a1d75..7d0d168749 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -194,6 +194,17 @@ CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { + // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to + // for PyTorch to factor externally allocated memory into its memory pool and garbage collection + // threshold calculation. + _ubuf_torch = torch::from_blob( + _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, + at::device(torch::kCUDA).dtype(buffer_dtype)); + if (_atomic_gemm) { + _ubuf_counter = torch::from_blob( + _counter.dptr(), {static_cast(_num_splits * 2)}, + at::device(torch::kCUDA).dtype(torch::kInt32)); + } } /* @@ -228,8 +239,7 @@ std::vector CommOverlap::bulk_overlap( (comm_type == te::CommOverlapType::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim1 = _ubuf.size(1); auto output_tensor = - torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, - torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); + torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); return {D, output_tensor}; } // CommOverlap::bulk_overlap @@ -332,7 +342,19 @@ CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::Scal helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, - comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {} + comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) { + // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to + // for PyTorch to factor externally allocated memory into its memory pool and garbage collection + // threshold calculation. + _ubuf_torch = torch::from_blob( + _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, + at::device(torch::kCUDA).dtype(buffer_dtype)); + if (_atomic_gemm) { + _ubuf_counter = torch::from_blob( + _counter.dptr(), {static_cast(_num_splits * 2)}, + at::device(torch::kCUDA).dtype(torch::kInt32)); + } +} /* ** Split AllGather + AtomicGEMM using P2P communication @@ -457,6 +479,5 @@ torch::Tensor CommOverlapP2P::get_ubuf_output(int comm_type) { int output_c_dim0 = (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, - torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); + return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); } From ce9c34d7a94a9e3a8f8b699aeb0628dd511e95b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 14:19:41 +0000 Subject: [PATCH 21/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions.h | 1 + .../pytorch/csrc/extensions/comm_gemm_overlap.cpp | 10 ++++------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 5d66004caf..e97e042557 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -605,6 +605,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm private: torch::Tensor _ubuf_torch; torch::Tensor _ubuf_counter; + public: CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 7d0d168749..67214dbebf 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -201,9 +201,8 @@ CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, at::device(torch::kCUDA).dtype(buffer_dtype)); if (_atomic_gemm) { - _ubuf_counter = torch::from_blob( - _counter.dptr(), 
{static_cast(_num_splits * 2)}, - at::device(torch::kCUDA).dtype(torch::kInt32)); + _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, + at::device(torch::kCUDA).dtype(torch::kInt32)); } } @@ -350,9 +349,8 @@ CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::Scal _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, at::device(torch::kCUDA).dtype(buffer_dtype)); if (_atomic_gemm) { - _ubuf_counter = torch::from_blob( - _counter.dptr(), {static_cast(_num_splits * 2)}, - at::device(torch::kCUDA).dtype(torch::kInt32)); + _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, + at::device(torch::kCUDA).dtype(torch::kInt32)); } } From 935a4035804397b28d95a68e9483d685b48f6159 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 6 Sep 2024 15:51:23 +0000 Subject: [PATCH 22/34] automated handling of world, local and node ranks/sizes within C++ CommOverlapHelper to simplify Python function signatures Signed-off-by: Alp Dener --- transformer_engine/pytorch/csrc/extensions.h | 3 +- .../csrc/extensions/comm_gemm_overlap.cpp | 35 ++++---- transformer_engine/pytorch/module/base.py | 85 ++++++++++--------- 3 files changed, 61 insertions(+), 62 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e97e042557..1ea88d3e46 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -528,7 +528,7 @@ class CommOverlapHelper : torch::CustomClassHolder { CommOverlapHelper(c10d::ProcessGroup *world_group, std::optional intra_node_group_holder, - std::optional inter_node_group_holder); + std::optional inter_node_group_holde); ~CommOverlapHelper(); @@ -685,6 +685,5 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output); }; // CommOverlapP2P ->>>>>>> 7f2dcc5 (added TE/PyTorch wrappers for refactored comm+GEMM overlap code in TE/common) #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 67214dbebf..799e8b3189 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -69,24 +69,23 @@ CommOverlapHelper::CommOverlapHelper() { } // empty constructor for NVTE_UB_WITH_MPI=1 CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_node_group_holder, - std::optional inter_node_group_holder) { + std::optional intra_domain_group, + std::optional inter_domain_group) { #ifndef NVTE_UB_WITH_MPI - myrank = world_group->getRank(); - numranks = world_group->getSize(); pgs.insert({"world", world_group}); - c10d::ProcessGroup::BackendType backend = world_group->getBackendType(); + myrank =pgs["world"]->getRank(); + numranks =pgs["world"]->getSize(); + c10d::ProcessGroup::BackendType backend =pgs["world"]->getBackendType(); backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); - if (intra_node_group_holder.has_value()) { + if (intra_domain_group.has_value()) { // Get local rank on node and number of local ranks - c10d::ProcessGroup *intra_node_group = inter_node_group_holder.value(); - NVTE_CHECK(intra_node_group->getBackendType() == backend, + NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend, 
"Internal TE error: Intra-node group must be on the same backend (%s) as the world ", - "group!", world_group->getBackendName()); - mylocal = intra_node_group->getRank(); - numlocal = intra_node_group->getSize(); - pgs.insert({"intra", intra_node_group}); + "group!",pgs["world"]->getBackendName()); + pgs.insert({"intra", intra_domain_group.value()}); + mylocal = pgs["intra"]->getRank(); + numlocal = pgs["intra"]->getSize(); if (numlocal == numranks) { // Intra-node group is same as the world group so there can only be 1 node @@ -99,18 +98,18 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, } else { // Intra-node group is different than the world group so there must be multiple nodes NVTE_CHECK( - inter_node_group_holder.has_value(), + inter_domain_group.has_value(), "Internal TE error: Inter-node group cannot be `None` when intra-node group is not ", "identical to the world_group!"); // Get node ID and number of nodes - c10d::ProcessGroup *inter_node_group = intra_node_group_holder.value(); NVTE_CHECK( - inter_node_group->getBackendType() == backend, + inter_domain_group.value()->getBackendType() == backend, "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", - "group!", world_group->getBackendName()); - mynode = inter_node_group->getRank(); - numnodes = inter_node_group->getSize(); + "group!", pgs["world"]->getBackendName()); + pgs.insert({"inter", inter_domain_group.value()}); + mynode = pgs["inter"]->getRank(); + numnodes = pgs["inter"]->getSize(); } } else { // Intra-node group is not set so we assume there is only 1 node diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 269547f574..2cd2165942 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -248,7 +248,9 @@ def initialize_ub( elif torch.distributed.is_gloo_available(): bootstrap_backend = "gloo" else: - assert bootstrap_backend in ["mpi", "gloo", "nccl"], "Invalid bootstrap backend!" + assert bootstrap_backend in ["gloo", "mpi", "nccl"], ( + "Invalid torch.distributed backend for bootstrapping Userbuffers!" + ) assert torch.distributed.is_backend_available(bootstrap_backend), ( f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." @@ -258,17 +260,16 @@ def initialize_ub( world_rank = torch.distributed.get_rank(world_group) world_size = torch.distributed.get_world_size(world_group) - # Construct an intra-node communicator based on global ranks that share the same hostname + # We have single-node NVLink so we can color based on physical node hostnames. # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host - # address on that interface instead of the hostname. This can help avoid issues when - # different hosts have the same hostname on Kubernetes clusters. - hostname = socket.gethostname() - ifname = os.getenv( - "NVTE_UB_SOCKET_IFNAME", - os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), - ) - + # address on that interface instead of the hostname. Otherwise, allow the user to + # set a network interface via NVTE_UB_SOCKET_IFNAME variable. This can help avoid + # issues when different hosts have the same hostname on managed clusters. 
+ mydomain = socket.gethostname() + ifname = os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME", + os.getenv("NVTE_UB_SOCKET_IFNAME")) if ifname is not None: +<<<<<<< HEAD # Make sure the ifname found in the environment is a valid network interface if ifname in [name for _, name in socket.if_nameindex()]: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -289,53 +290,53 @@ def initialize_ub( + "which is known to fail on virtual clusters like Kubernetes. If Userbuffers " + "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in " + "your environment to the correct network interface." - ) warnings.warn(ifname_warning, UserWarning) - hostnames = [None for _ in range(world_size)] - torch.distributed.all_gather_object(hostnames, hostname, world_group) - unique_hosts = [] - for host in hostnames: - if host not in unique_hosts: - unique_hosts.append(host) - num_nodes = len(unique_hosts) - - if num_nodes > 1: - ranks_per_node_list = [[] for _ in range(num_nodes)] - self_node_idx = -1 - for i, host in enumerate(hostnames): - node_idx = unique_hosts.index(host) - ranks_per_node_list[node_idx].append(i) - if host == hostname: - self_node_idx = node_idx - assert self_node_idx >= 0, "Internal TE error!" - - intra_node_group, _ = torch.distributed.new_subgroups_by_enumeration( - ranks_per_node_list, backend=bootstrap_backend + # Allgather the domain colors across ranks and reduce to a list of unique domains + domain_per_rank_list = [None for _ in range(world_size)] + torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group) + unique_domains = [] + for domain in domain_per_rank_list: + if domain not in unique_domains: + unique_domains.append(domain) + num_domains = len(unique_domains) + + if num_domains > 1: + # DP/TP model replicated on multiple NVLink domains + ranks_per_domain_list = [[] for _ in range(num_domains)] + mydomain_idx = -1 + for i, domain in enumerate(domain_per_rank_list): + domain_idx = unique_domains.index(domain) + ranks_per_domain_list[domain_idx].append(i) + if domain == mydomain: + mydomain_idx = domain_idx + assert mydomain_idx >= 0, "Internal TE error!" + + intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( + ranks_per_domain_list, backend=bootstrap_backend ) - local_rank = torch.distributed.get_rank(intra_node_group) - intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group) + local_rank = torch.distributed.get_rank(intra_domain_group) - ranks_per_node_tensor = torch.tensor(ranks_per_node_list, dtype=int) - ranks_across_nodes_list = ranks_per_node_tensor.transpose(0, 1).tolist() - inter_node_group, _ = torch.distributed.new_subgroups_by_enumeration( - ranks_across_nodes_list, backend=bootstrap_backend + inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( + [ list(ranks) for ranks in zip(*ranks_per_domain_list) ], + backend=bootstrap_backend, ) - helper = tex.CommOverlapHelper(world_group, intra_node_group, inter_node_group) + helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group) else: - self_node_idx = 0 + # TP model on single NVLink domain, no replication, no data-parallelism + mydomain_idx = 0 local_rank = world_rank - intra_node_ranks = list(range(world_size)) + intra_domain_ranks = list(range(world_size)) helper = tex.CommOverlapHelper(world_group) if world_rank == 0: - print(f"!!! [UB] Number of NVLink domains: {num_nodes}\n", end="", flush=True) + print(f"!!! 
[UB] Number of NVLink domains: {num_domains}\n", end="", flush=True) if local_rank == 0: print( - f"!!! [UB] Global ranks in NVLink domain #{self_node_idx}: {intra_node_ranks}\n", + f"!!! [UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n", end="", flush=True, ) From d80765e3a8fe39ed4f51c3f9254d8d2f55d79912 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 15:52:46 +0000 Subject: [PATCH 23/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../csrc/extensions/comm_gemm_overlap.cpp | 8 ++++---- transformer_engine/pytorch/module/base.py | 17 ++++++++++------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 799e8b3189..d212d13516 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -73,16 +73,16 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, std::optional inter_domain_group) { #ifndef NVTE_UB_WITH_MPI pgs.insert({"world", world_group}); - myrank =pgs["world"]->getRank(); - numranks =pgs["world"]->getSize(); - c10d::ProcessGroup::BackendType backend =pgs["world"]->getBackendType(); + myrank = pgs["world"]->getRank(); + numranks = pgs["world"]->getSize(); + c10d::ProcessGroup::BackendType backend = pgs["world"]->getBackendType(); backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); if (intra_domain_group.has_value()) { // Get local rank on node and number of local ranks NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend, "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", - "group!",pgs["world"]->getBackendName()); + "group!", pgs["world"]->getBackendName()); pgs.insert({"intra", intra_domain_group.value()}); mylocal = pgs["intra"]->getRank(); numlocal = pgs["intra"]->getSize(); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2cd2165942..8fdc1cc2f6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -248,9 +248,11 @@ def initialize_ub( elif torch.distributed.is_gloo_available(): bootstrap_backend = "gloo" else: - assert bootstrap_backend in ["gloo", "mpi", "nccl"], ( - "Invalid torch.distributed backend for bootstrapping Userbuffers!" - ) + assert bootstrap_backend in [ + "gloo", + "mpi", + "nccl", + ], "Invalid torch.distributed backend for bootstrapping Userbuffers!" assert torch.distributed.is_backend_available(bootstrap_backend), ( f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." @@ -266,10 +268,10 @@ def initialize_ub( # set a network interface via NVTE_UB_SOCKET_IFNAME variable. This can help avoid # issues when different hosts have the same hostname on managed clusters. 
mydomain = socket.gethostname() - ifname = os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME", - os.getenv("NVTE_UB_SOCKET_IFNAME")) + ifname = os.getenv( + f"{bootstrap_backend.upper()}_SOCKET_IFNAME", os.getenv("NVTE_UB_SOCKET_IFNAME") + ) if ifname is not None: -<<<<<<< HEAD # Make sure the ifname found in the environment is a valid network interface if ifname in [name for _, name in socket.if_nameindex()]: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -290,6 +292,7 @@ def initialize_ub( + "which is known to fail on virtual clusters like Kubernetes. If Userbuffers " + "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in " + "your environment to the correct network interface." + ) warnings.warn(ifname_warning, UserWarning) # Allgather the domain colors across ranks and reduce to a list of unique domains @@ -318,7 +321,7 @@ def initialize_ub( local_rank = torch.distributed.get_rank(intra_domain_group) inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( - [ list(ranks) for ranks in zip(*ranks_per_domain_list) ], + [list(ranks) for ranks in zip(*ranks_per_domain_list)], backend=bootstrap_backend, ) From e5d31cec0aa4fd816b3cdbeafea4d83a47e222be Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 9 Oct 2024 19:49:21 +0000 Subject: [PATCH 24/34] fixed incorrect read of environment variables Signed-off-by: Alp Dener --- build_tools/pytorch.py | 2 +- setup.py | 2 +- transformer_engine/pytorch/module/base.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index f438d203d6..e7924d8a21 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -81,7 +81,7 @@ def setup_pytorch_extension( continue # Already handled nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"]) - if os.getenv("NVTE_UB_WITH_MPI") is not None: + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): assert ( os.getenv("MPI_HOME") is not None ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" diff --git a/setup.py b/setup.py index fbf7a37723..3bb2fe6b95 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ def run(self): def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())] - if os.getenv("NVTE_UB_WITH_MPI"): + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): assert ( os.getenv("MPI_HOME") is not None ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 8fdc1cc2f6..4113d39c87 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -220,7 +220,7 @@ def initialize_ub( which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time. """ if not tex.device_supports_multicast(): - assert bool(os.getenv("UB_SKIPMC", "0")), ( + assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." 
) From 3c53354bfa333bf2586a2e8b2b1567cbab56e661 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 9 Oct 2024 19:59:25 +0000 Subject: [PATCH 25/34] corrected priority for _SOCKET_IFNAME environment variables in UB bootstrapping Signed-off-by: Alp Dener --- transformer_engine/pytorch/module/base.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 4113d39c87..0ed5507821 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -263,20 +263,19 @@ def initialize_ub( world_size = torch.distributed.get_world_size(world_group) # We have single-node NVLink so we can color based on physical node hostnames. - # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host - # address on that interface instead of the hostname. Otherwise, allow the user to - # set a network interface via NVTE_UB_SOCKET_IFNAME variable. This can help avoid - # issues when different hosts have the same hostname on managed clusters. + # NOTE: Prefer a network interface defined via the NVTE_UB_SOCKET_IFNAME variable, and + # otherwise fall back on NCCL_SOCKET_IFNAME or GLOO_SOCKET_IFNAME depending on + # the chosen bootstrap backend. mydomain = socket.gethostname() ifname = os.getenv( - f"{bootstrap_backend.upper()}_SOCKET_IFNAME", os.getenv("NVTE_UB_SOCKET_IFNAME") + "NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME") ) if ifname is not None: # Make sure the ifname found in the environment is a valid network interface if ifname in [name for _, name in socket.if_nameindex()]: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: - hostname = socket.inet_ntoa( + mydomain = socket.inet_ntoa( fcntl.ioctl( s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) )[20:24] @@ -288,10 +287,11 @@ def initialize_ub( else: ifname_warning = ( f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will" - " attempt to detect ranks on the same node by matching 'socket.gethostname()', " - + "which is known to fail on virtual clusters like Kubernetes. If Userbuffers " - + "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in " - + "your environment to the correct network interface." + + " attempt to detect ranks on the same node by matching " + + "'socket.gethostname()', which is known to fail on virtual clusters like " + + "Kubernetes. If Userbuffers initialization fails, please set the " + + "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network " + + "interface." 
) warnings.warn(ifname_warning, UserWarning) From 1776282290c9cb9318c605372b0f716d4a4e3346 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Mon, 21 Oct 2024 19:24:39 +0000 Subject: [PATCH 26/34] moved multicast support check to cuda_runtime.h and replaced cudaDeviceGetProp call with cached sm_count() Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 18 ++---------- .../common/util/cuda_runtime.cpp | 28 +++++++++++++++++++ transformer_engine/common/util/cuda_runtime.h | 8 ++++++ .../common/util/pybind_helper.h | 6 ++-- 4 files changed, 43 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 72a24a2add..420ffffd7d 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -13,6 +13,7 @@ #include "common/common.h" #include "common/util/cuda_driver.h" +#include "common/util/cuda_runtime.h" #include "common/util/logging.h" #include "common/util/system.h" #include "userbuffers/userbuffers.h" @@ -28,18 +29,6 @@ namespace transformer_engine { * Comm+GEMM Overlap Common Core **************************************************************************************************/ -bool device_supports_multicast() { - int dev, supports_multicast; - CUdevice cudev; - - NVTE_CHECK_CUDA(cudaGetDevice(&dev)); - NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, dev); - NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &supports_multicast, - CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); - - return static_cast(supports_multicast); -} - bool ubuf_built_with_mpi() { #ifdef NVTE_UB_WITH_MPI return true; @@ -82,9 +71,8 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl _tp_id = _rank % _tp_size; // Set the number of SMs for GEMM with margin - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; + int sm_count = transformer_engine::cuda::sm_count(); + _math_sms = (set_sm_margin) ? 
sm_count - num_comm_sm : sm_count; _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); _atomic_gemm = atomic_gemm; diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 5728ef557a..b722db8570 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -12,6 +12,7 @@ #include "../common.h" #include "../util/cuda_driver.h" #include "../util/system.h" +#include "common/util/cuda_runtime.h" namespace transformer_engine { @@ -80,6 +81,33 @@ int sm_count(int device_id) { return cache[device_id]; } +bool supports_multicast(int device_id) { + static std::vector cache(num_devices(), false); + static std::vector flags(num_devices()); + if (device_id < 0) { + device_id = current_device(); + } + NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); + auto init = [&]() { + int cuda_runtime_version; + NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cuda_runtime_version)); + if (cuda_runtime_version >= 12010) { + // On CUDA 12.1+, we can directly query driver for the multicast support property + CUdevice cudev; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id); + int result; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); + cache[device_id] = static_cast(result); + } else { + // Otherwise, we can assume CUDA Multicast is supported with CUDA 12.0+ and SM arch 9.0+ + cache[device_id] = (cuda_runtime_version >= 12000 && sm_arch() > 90); + } + }; + std::call_once(flags[device_id], init); + return cache[device_id]; +} + const std::string &include_directory(bool required) { static std::string path; diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index b6b4c41610..ea1ba84772 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -38,6 +38,14 @@ int sm_arch(int device_id = -1); */ int sm_count(int device_id = -1); +/* \brief CUDA Multicast support status for device + * + * \param[in] device_id CUDA device (default is current device) + * + * \return CUDA multicast support flag + */ +bool supports_multicast(int device_id = -1); + /* \brief Path to CUDA Toolkit headers * * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index a90077f935..432ac815ec 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -12,6 +12,8 @@ #include #include +#include "cuda_runtime.h" + #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ pybind11::enum_(m, "DType") \ .value("kByte", transformer_engine::DType::kByte) \ @@ -69,8 +71,8 @@ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - m.def("device_supports_multicast", &transformer_engine::device_supports_multicast, \ - py::call_guard()); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ py::call_guard()); From 9dd300cca027247d774508045295ec2e469c6b50 Mon Sep 17 00:00:00 2001 From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Oct 2024 19:31:15 +0000 Subject: [PATCH 27/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/util/cuda_runtime.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index b722db8570..c7d98323c8 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -85,7 +85,7 @@ bool supports_multicast(int device_id) { static std::vector cache(num_devices(), false); static std::vector flags(num_devices()); if (device_id < 0) { - device_id = current_device(); + device_id = current_device(); } NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); auto init = [&]() { From d99734a8e794aedfc43d15469c466bbfce82edc3 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 23 Oct 2024 17:26:43 +0000 Subject: [PATCH 28/34] removed commented out old code and replaced external collective function type defines with aliases Signed-off-by: Alp Dener --- .../common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp | 3 --- .../common/comm_gemm_overlap/userbuffers/userbuffers.h | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index c6f11005a4..e62dc1549d 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -287,9 +287,6 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, #define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF) // peer pointers + op flags + comm buffer - // NVTE_CHECK_CUDA( - // cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet - // NVTE_CHECK_CUDA(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE)); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true); NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int))); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 840cb468d2..57e68afce0 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -24,8 +24,8 @@ #define ExtComm const char * #endif -#define ExtAllgatherOp std::function -#define ExtBarrierOp std::function +using ExtAllgatherOp = std::function; +using ExtBarrierOp = std::function; #define NVTE_MAX_REGIONS 16 #define NVTE_MAX_SMS 32 From 94dbe6a834624bc1989552ae3605bd88c0f97e02 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 23 Oct 2024 21:31:03 +0000 Subject: [PATCH 29/34] compile-time CUDA version guard for CUDA Driver Multicast attribute Signed-off-by: Alp Dener --- .../common/util/cuda_runtime.cpp | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index c7d98323c8..8d2e852988 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ 
b/transformer_engine/common/util/cuda_runtime.cpp @@ -82,6 +82,9 @@ int sm_count(int device_id) { } bool supports_multicast(int device_id) { +#if CUDART_VERSION >= 12010 + // NOTE: This needs to be guarded at compile time because the + // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions. static std::vector cache(num_devices(), false); static std::vector flags(num_devices()); if (device_id < 0) { @@ -89,23 +92,18 @@ bool supports_multicast(int device_id) { } NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); auto init = [&]() { - int cuda_runtime_version; - NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cuda_runtime_version)); - if (cuda_runtime_version >= 12010) { - // On CUDA 12.1+, we can directly query driver for the multicast support property - CUdevice cudev; - NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id); - int result; - NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result, - CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); - cache[device_id] = static_cast(result); - } else { - // Otherwise, we can assume CUDA Multicast is supported with CUDA 12.0+ and SM arch 9.0+ - cache[device_id] = (cuda_runtime_version >= 12000 && sm_arch() > 90); - } + CUdevice cudev; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id); + int result; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); + cache[device_id] = static_cast(result); }; std::call_once(flags[device_id], init); return cache[device_id]; +#else + return false; +#endif } const std::string &include_directory(bool required) { From 9c60c00d88a5adbedfff0968549001f48352a75e Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 23 Oct 2024 22:22:56 +0000 Subject: [PATCH 30/34] added compile-time CUDA version guards to Multicast code in Userbuffers Signed-off-by: Alp Dener --- .../userbuffers/userbuffers-host.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index e62dc1549d..056ac8b11a 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -20,7 +20,9 @@ #include #include "common/util/cuda_driver.h" +#include "common/util/cuda_runtime.h" #include "common/util/logging.h" +#include "common/util/system.h" #include "ipcsocket.h" #include "userbuffers.h" @@ -201,8 +203,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, (*comm)->asyncblocks = 16; #define NBUF 2 - if ((*comm)->sm_arch >= 9 && (*comm)->ar2_nvsize > 1 && - !getenv("UB_SKIPMC")) { // multicast init only for TP ops (____2 operations) + +#if CUDART_VERSION >= 12010 + if (!transformer_engine::getenv("UB_SKIPMC") + && transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) { + // multicast init only for TP ops (____2 operations) size_t mc_maxsize = MULTICAST_GB_TOTAL * (1ull << 30); (*comm)->mc_offset = 0; (*comm)->use_mc = 1; @@ -278,11 +283,14 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, (*comm)->_barrier((*comm)->comm_world); if (!(*comm)->myrank) printf("MC initialized succesfully, window size = %ld\n", mc_maxsize); } else { +#endif if (!(*comm)->myrank) printf("MC NOT initialized and used\n"); (*comm)->mc_maxsize = 0; (*comm)->mc_offset = 0; 
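Patches 26 and 29 arrive, in two steps, at a useful pattern: a per-device capability query cached behind std::call_once, with the CUDA 12.1-only CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED query guarded at compile time via CUDART_VERSION rather than branched at runtime. The sketch below consolidates that pattern in a self-contained form using raw CUDA runtime/driver calls; it is not the verbatim TE implementation, which routes through NVTE_CHECK / NVTE_CALL_CHECK_CUDA_DRIVER and its own device helpers.

```cpp
// Consolidated sketch of the cached multicast-support query (not the verbatim TE code).
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdio>
#include <mutex>
#include <vector>

static int device_count() {
  int n = 0;
  cudaGetDeviceCount(&n);  // production code should check the returned cudaError_t
  return n;
}

bool supports_multicast_sketch(int device_id = -1) {
#if CUDART_VERSION >= 12010
  // Guarded at compile time: CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED is not declared
  // in pre-12.1 CUDA headers, so the query cannot even be compiled there.
  static std::vector<bool> cache(device_count(), false);
  static std::vector<std::once_flag> flags(device_count());
  if (device_id < 0) cudaGetDevice(&device_id);
  if (device_id < 0 || device_id >= static_cast<int>(cache.size())) return false;
  std::call_once(flags[device_id], [&]() {
    cuInit(0);  // idempotent; makes the sketch safe to call before any other CUDA work
    CUdevice dev;
    int result = 0;
    if (cuDeviceGet(&dev, device_id) == CUDA_SUCCESS &&
        cuDeviceGetAttribute(&result, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev) ==
            CUDA_SUCCESS) {
      cache[device_id] = (result != 0);
    }
  });
  return cache[device_id];
#else
  (void)device_id;
  return false;  // conservatively report no multicast support on pre-12.1 toolkits
#endif
}

int main() {
  std::printf("multicast supported on device 0: %d\n",
              static_cast<int>(supports_multicast_sketch(0)));
  return 0;
}
```

On the Python side, the binding shown earlier exposes the same query as device_supports_multicast(device_id=-1), and the Userbuffers bootstrap in this patch additionally requires ar2_nvsize > 1 and an unset UB_SKIPMC environment variable before it actually enables multicast.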
(*comm)->use_mc = 0; +#if CUDART_VERSION >= 12010 } +#endif #define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF) // peer pointers + op flags + comm buffer @@ -462,6 +470,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->memflags[hndl] = 0; comm->mem_dealloc[hndl] = alloc; +#if CUDART_VERSION >= 12010 if (comm->use_mc && alloc) { int nranks = comm->nvsize; // total GPUs in NVLINK domain int myrank = comm->nvrank; @@ -577,6 +586,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * } } else { +#endif if (alloc) { NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes)); NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes)); @@ -607,7 +617,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * NVTE_CHECK_CUDA(cudaDeviceSynchronize()); free(tmp); +#if CUDART_VERSION >= 12010 } +#endif comm->mem_size[hndl] = aligned_size; comm->mem_ptr[hndl] = *gpubuff; From 452b52273e13ed5d87472c1b11e9df8db8cd7a5e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 22:23:27 +0000 Subject: [PATCH 31/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 056ac8b11a..6f3eef3d28 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -205,8 +205,8 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, #define NBUF 2 #if CUDART_VERSION >= 12010 - if (!transformer_engine::getenv("UB_SKIPMC") - && transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) { + if (!transformer_engine::getenv("UB_SKIPMC") && + transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) { // multicast init only for TP ops (____2 operations) size_t mc_maxsize = MULTICAST_GB_TOTAL * (1ull << 30); (*comm)->mc_offset = 0; From c1bded43efeb78b75191be039c5a99ee7ec03157 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 24 Oct 2024 17:43:06 +0000 Subject: [PATCH 32/34] condensed UB docs, corrected const violations Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 64 ++++----- .../transformer_engine/comm_gemm_overlap.h | 71 +++++----- transformer_engine/pytorch/csrc/extensions.h | 4 +- transformer_engine/pytorch/module/base.py | 121 +++--------------- 4 files changed, 89 insertions(+), 171 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 420ffffd7d..e157915116 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -150,12 +150,12 @@ CommOverlapBase::~CommOverlapBase() { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ -void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, const TensorWrapper &D, const TensorWrapper &bias, - const TensorWrapper &pre_gelu_out, - const TensorWrapper 
&workspace, bool grad, bool accumulate, +void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, - const TensorWrapper &rs_output, cudaStream_t stream_main) { + TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -198,10 +198,10 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te ** Split FPROP GEMM + ReduceScatter */ void CommOverlapBase::atomic_gemm_overlap_rs( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, const TensorWrapper &rs_output, cudaStream_t stream_main) { + TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -296,12 +296,12 @@ void CommOverlapBase::atomic_gemm_overlap_rs( /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, const TensorWrapper &D, - const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, +void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, bool gemm_overlap, - const TensorWrapper &rs_output, cudaStream_t stream_main) { + TensorWrapper &rs_output, cudaStream_t stream_main) { // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -533,10 +533,10 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { *phases. */ void CommOverlapP2PBase::atomic_gemm_overlap_ag( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &B_copy, cudaStream_t stream_main) { + TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -637,13 +637,13 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( ** in each rank to be in the contiguous memory space after all ring exchange *phases. 
*/ -void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, - const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, - const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, +void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, + TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &B_copy, cudaStream_t stream_main) { + TensorWrapper &B_copy, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -817,10 +817,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ** Split ReduceScatter + GEMM using P2P communication */ void CommOverlapP2PBase::atomic_gemm_overlap_rs( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &rs_output, cudaStream_t stream_main) { + TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -885,10 +885,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( ** Split ReduceScatter + GEMM using P2P communication */ void CommOverlapP2PBase::split_overlap_rs( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, const TensorWrapper &pre_gelu_out, - const TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &rs_output, cudaStream_t stream_main) { + TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 55622d6d43..a56f41e8eb 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -37,7 +37,7 @@ enum class CommOverlapAlgo { }; class CommOverlapCore { - public: + protected: static inline communicator *_ub_comm{nullptr}; static inline bool _comm_created{false}; @@ -49,10 +49,10 @@ class CommOverlapCore { int _num_comm_sm; int _cga_size; int _use_ce; + int _ub_reg; bool _atomic_gemm{false}; bool _is_p2p{false}; - int _ub_reg; TensorWrapper _ubuf; TensorWrapper _counter; float *_ubuf_scale_inv; @@ -61,6 +61,7 @@ class CommOverlapCore { std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm; + public: CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int 
num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, @@ -81,51 +82,52 @@ class CommOverlapCore { }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { - public: + protected: int _rs_kernel_type; cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; + public: CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); - ~CommOverlapBase(); + virtual ~CommOverlapBase(); /* ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ - void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, - const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, bool grad, + void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, - const TensorWrapper &rs_output, cudaStream_t stream_main); + TensorWrapper &rs_output, cudaStream_t stream_main); /* ** Split FPROP GEMM + ReduceScatter */ - void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, const TensorWrapper &D, const TensorWrapper &bias, - const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, const TensorWrapper &rs_output, + bool gemm_overlap, TensorWrapper &rs_output, cudaStream_t stream_main); /* ** Split FPROP GEMM + ReduceScatter */ - void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, - const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, bool gemm_overlap, - const TensorWrapper &rs_output, cudaStream_t stream_main); + TensorWrapper &rs_output, cudaStream_t stream_main); }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { - public: + protected: bool _is_reduce_scatter{false}; bool _use_multiatomic_ag{false}; @@ -142,6 +144,7 @@ class CommOverlapP2PBase : public CommOverlapCore { cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; + public: CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, @@ -149,7 +152,7 @@ class CommOverlapP2PBase : public CommOverlapCore { int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); - ~CommOverlapP2PBase(); + virtual ~CommOverlapP2PBase(); /* ** Split AllGather + 
AtomicGEMM using P2P communication @@ -158,11 +161,11 @@ class CommOverlapP2PBase : public CommOverlapCore { ** in each rank to be in the contiguous memory space after all ring exchange *phases. */ - void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, const TensorWrapper &D, const TensorWrapper &bias, - const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &B_copy, cudaStream_t stream_main); + TensorWrapper &B_copy, cudaStream_t stream_main); /* ** Split AllGather + GEMM using P2P communication @@ -171,29 +174,29 @@ class CommOverlapP2PBase : public CommOverlapCore { ** in each rank to be in the contiguous memory space after all ring exchange *phases. */ - void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, - const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &B_copy, cudaStream_t stream_main); + TensorWrapper &B_copy, cudaStream_t stream_main); /* ** Split ReduceScatter + GEMM using P2P communication */ - void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, const TensorWrapper &D, const TensorWrapper &bias, - const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &rs_output, cudaStream_t stream_main); + TensorWrapper &rs_output, cudaStream_t stream_main); /* ** Split ReduceScatter + GEMM using P2P communication */ - void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - const TensorWrapper &D, const TensorWrapper &bias, - const TensorWrapper &pre_gelu_out, const TensorWrapper &workspace, + void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - const TensorWrapper &rs_output, cudaStream_t stream_main); + TensorWrapper &rs_output, cudaStream_t stream_main); }; // CommOverlapP2PBase } // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1ea88d3e46..b039bf2d1b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -527,8 +527,8 @@ class CommOverlapHelper : torch::CustomClassHolder { CommOverlapHelper(); CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_node_group_holder, - std::optional inter_node_group_holde); + std::optional intra_node_group, + std::optional inter_node_group); ~CommOverlapHelper(); diff --git a/transformer_engine/pytorch/module/base.py 
b/transformer_engine/pytorch/module/base.py index 0ed5507821..b2c62c5f6b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -104,111 +104,26 @@ def initialize_ub( dtype : torch.dtype = torch.bfloat16 non-FP8 data type of the communication buffer when `use_fp8 = False` ub_cfgs: dict = None - Configuration dictionary containing a Userbuffers options for each GEMM layer in a - te.TransformerLayer. Layers that are not configured by the user fall back on the - default options below: + Configuration dictionary with the structure + ``` { - "qkv_fprop": { - "method": "ring_exchange", - "is_reduce_scatter": False, - "num_sm": 1, - "cga_size": 1, - "set_sm_margin": False, - "num_splits": tp_size, - "aggregate": False, - "atomic_gemm": False, - "use_ce": True, - "fp8_buf": True, - }, - "qkv_dgrad": { - "method": "bulk", - "is_reduce_scatter": False, - "num_sm": 16, - "cga_size": 2, - "set_sm_margin": True, - "fp8_buf": True, - }, - "qkv_wgrad": { - "method": "bulk", - "is_reduce_scatter": True, - "num_sm": 16, - "cga_size": 2, - "set_sm_margin": True, - "fp8_buf": True, - }, - "proj_fprop": { - "method": "pipeline", - "is_reduce_scatter": True, - "num_sm": 16, - "cga_size": 2, - "set_sm_margin": True, - "num_splits": 4, - "atomic_gemm": False, - "fp8_buf": False, - }, - "proj_dgrad": { - "method": "ring_exchange", - "is_reduce_scatter": False, - "num_sm": 1, - "cga_size": 1, - "set_sm_margin": False, - "num_splits": tp_size, - "aggregate": False, - "atomic_gemm": False, - "use_ce": True, - "fp8_buf": True, - }, - "fc1_fprop": { - "method": "ring_exchange", - "is_reduce_scatter": False, - "num_sm": 1, - "cga_size": 1, - "set_sm_margin": False, - "num_splits": tp_size, - "aggregate": False, - "atomic_gemm": False, - "use_ce": True, - "fp8_buf": True, - }, - "fc1_dgrad": { - "method": "bulk", - "is_reduce_scatter": False, - "num_sm": 16, - "cga_size": 2, - "set_sm_margin": True, - "fp8_buf": True, - }, - "fc1_wgrad": { - "method": "bulk", - "is_reduce_scatter": True, - "num_sm": 16, - "cga_size": 2, - "set_sm_margin": True, - "fp8_buf": False, - }, - "fc2_fprop": { - "method": "pipeline", - "is_reduce_scatter": True, - "num_sm": 16, - "cga_size": 2, - "set_sm_margin": True, - "num_splits": 4, - "atomic_gemm": False, - "fp8_buf": False, - }, - "fc2_dgrad": { - "method": "ring_exchange", - "is_reduce_scatter": False, - "num_sm": 1, - "cga_size": 1, - "set_sm_margin": False, - "num_splits": tp_size, - "aggregate": False, - "atomic_gemm": False, - "use_ce": True, - "fp8_buf": True, - }, + : { + "method": <"ring_exchange" or "pipeline">, + "is_reduce_scatter": bool, + "num_sm": int, + "cga_size": int, + "set_sm_margin": bool, + "num_splits": int, + "aggregate": bool, + "atomic_gemm": bool, + "use_ce": bool, + "fp8_buf": bool, + } } + ``` + for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", + "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", + "fc2_fprop", "fc2_dgrad"]`. bootstrap_backend : str = None `torch.distributed` communication backend for the all-gather, broadcast and barrier collectives during Userbuffers initialization. 
Not all backends are From a5504f17b12b1aa3302e83917b80d2af9e92bdab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Oct 2024 17:44:41 +0000 Subject: [PATCH 33/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 78 ++++++++++--------- .../transformer_engine/comm_gemm_overlap.h | 64 ++++++++------- 2 files changed, 71 insertions(+), 71 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index e157915116..4a73b5ca5c 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -150,12 +150,12 @@ CommOverlapBase::~CommOverlapBase() { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ -void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, CommOverlapType comm_type, - TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -197,11 +197,12 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::atomic_gemm_overlap_rs( - TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, TensorWrapper &rs_output, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -296,12 +297,12 @@ void CommOverlapBase::atomic_gemm_overlap_rs( /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, - TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, - TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, TensorWrapper &rs_output, + cudaStream_t stream_main) { // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -532,11 
+533,12 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { ** in each rank to be in the contiguous memory space after all ring exchange *phases. */ -void CommOverlapP2PBase::atomic_gemm_overlap_ag( - TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { +void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -637,12 +639,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( ** in each rank to be in the contiguous memory space after all ring exchange *phases. */ -void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, - TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, +void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -816,11 +816,13 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, /* ** Split ReduceScatter + GEMM using P2P communication */ -void CommOverlapP2PBase::atomic_gemm_overlap_rs( - TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -884,11 +886,11 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( /* ** Split ReduceScatter + GEMM using P2P communication */ -void CommOverlapP2PBase::split_overlap_rs( - TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h 
b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index a56f41e8eb..819e9e33ce 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -100,30 +100,28 @@ class CommOverlapBase : public CommOverlapCore { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ - void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, - TensorWrapper &rs_output, cudaStream_t stream_main); + void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main); /* ** Split FPROP GEMM + ReduceScatter */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main); + void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, bool gemm_overlap, + TensorWrapper &rs_output, cudaStream_t stream_main); /* ** Split FPROP GEMM + ReduceScatter */ void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, bool gemm_overlap, - TensorWrapper &rs_output, cudaStream_t stream_main); + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output, + cudaStream_t stream_main); }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { @@ -161,11 +159,11 @@ class CommOverlapP2PBase : public CommOverlapCore { ** in each rank to be in the contiguous memory space after all ring exchange *phases. */ - void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main); + void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main); /* ** Split AllGather + GEMM using P2P communication @@ -175,28 +173,28 @@ class CommOverlapP2PBase : public CommOverlapCore { *phases. 
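Earlier in this series (PATCH 32) the overlap classes moved their internal state from public to protected and declared the CommOverlapBase / CommOverlapP2PBase destructors virtual, which matters once an overlap object can be owned and destroyed through a base-class pointer. The stand-in classes below only illustrate that mechanic; the real destructors release CUDA streams and events rather than printing.

```cpp
// Stand-in classes illustrating the virtual-destructor change (not TE's actual hierarchy).
#include <cstdio>
#include <memory>

class OverlapCore {
 public:
  virtual ~OverlapCore() { std::printf("core cleanup\n"); }  // virtual: delete-by-base is safe

 protected:
  int num_comm_sm_ = 16;  // protected rather than public: only derived classes see internals
};

class OverlapP2P : public OverlapCore {
 public:
  ~OverlapP2P() override { std::printf("p2p cleanup (send/recv streams, events, ...)\n"); }
};

int main() {
  std::unique_ptr<OverlapCore> ub = std::make_unique<OverlapP2P>();
  ub.reset();  // runs ~OverlapP2P first, then ~OverlapCore
  return 0;
}
```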
*/ void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main); + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main); /* ** Split ReduceScatter + GEMM using P2P communication */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main); + void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main); /* ** Split ReduceScatter + GEMM using P2P communication */ void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main); + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main); }; // CommOverlapP2PBase } // namespace transformer_engine From 7dd25c5e943d031dc20b469d5e99244d48980200 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 24 Oct 2024 19:29:06 +0000 Subject: [PATCH 34/34] fixed autodoc rst for UB calls, added CUDA version guard on Multicast UB kernels Signed-off-by: Alp Dener --- docs/api/pytorch.rst | 4 ++-- .../comm_gemm_overlap/userbuffers/userbuffers.cu | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 32b60f5bfb..ba4e7db352 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -52,6 +52,6 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_unpermute -.. autoapifunction:: transformer_engine.initialize_ub +.. autoapifunction:: transformer_engine.pytorch.initialize_ub -.. autoapifunction:: transformer_engine.destroy_ub +.. 
autoapifunction:: transformer_engine.pytorch.destroy_ub diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index a45d91a387..26843d8107 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -392,14 +392,14 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) -#if __CUDA_ARCH__ >= 900 +#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 // All MC kernels here template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int lineoffset, const int numlines, void **commbuff, const int handleridx, - float4 *mc_ptr) { + float4 *mc_ptr, const uint64_t ub_timeout) { int *flagptr, physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; @@ -417,7 +417,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { + if (clock64() - s > ub_timeout) { UB_PRINT("Reduce-scatter: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id, *flag); break; @@ -484,7 +484,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > 2ull * TIMEOUT) { + if (clock64() - s > 2ull * ub_timeout) { UB_PRINT("Allgather: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id, *flag); break; @@ -741,7 +741,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int lineoffset, const int numlines, void **commbuff, const int handleridx, - float4 *mc_ptr) {} + float4 *mc_ptr, const uint64_t ub_timeout) {} template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset,
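The final commit in the series replaces the hard-coded TIMEOUT constant in the multicast kernels with a per-launch ub_timeout argument while keeping the clock64()-based watchdog on the flag spin-waits. Below is a stripped-down sketch of that wait pattern; the flag protocol and the timeout value are stand-ins, not the actual CHECK_IDS / UB_PRINT machinery, and the demo launch deliberately never sets the flag so that the timeout path fires.

```cuda
// Minimal device-side spin-wait with a cycle-count timeout (stand-in for the UB flag waits).
#include <cstdio>
#include <cuda_runtime.h>

__device__ bool wait_for_flag(volatile int *flag, int expected, long long timeout_cycles) {
  const long long start = clock64();
  while (*flag != expected) {
    if (clock64() - start > timeout_cycles) {
      printf("spin-wait timed out: expected %d, saw %d\n", expected, *flag);
      return false;  // give up instead of hanging the kernel forever
    }
  }
  return true;
}

__global__ void consumer(volatile int *flag, long long timeout_cycles) {
  if (threadIdx.x == 0) {
    // Passing the timeout as a kernel argument (like ub_timeout) lets it be tuned at
    // runtime, e.g. from an environment variable, without recompiling the kernel.
    wait_for_flag(flag, /*expected=*/1, timeout_cycles);
  }
}

int main() {
  int *flag = nullptr;
  cudaMalloc(&flag, sizeof(int));
  cudaMemset(flag, 0, sizeof(int));
  consumer<<<1, 32>>>(flag, 1LL << 26);  // ~67M cycles, then the watchdog reports and exits
  cudaDeviceSynchronize();
  cudaFree(flag);
  return 0;
}
```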