From 8e97c8da81082fddae66b9693219780c9b66736f Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 17 Oct 2024 22:48:41 +0800 Subject: [PATCH] [Bugfix] Fix bias for 0-dim tensors in gemm (#1246) * fix bias for 0-dim tensor Signed-off-by: Xin Yao * add check Signed-off-by: Xin Yao * use numel() instead of nullptr Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao --- .../pytorch/csrc/extensions/gemm.cu | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index ba9851e7e8..40b96a057f 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -15,10 +15,16 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { using namespace transformer_engine; - if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) { - if (D.data_ptr() != nullptr && !accumulate) D.zero_(); - if (bias.data_ptr() != nullptr) bias.zero_(); - if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_(); + if (A.numel() == 0 || B.numel() == 0) { + if (D.numel() != 0 && !accumulate) D.zero_(); + if (bias.numel() != 0 && grad) { + if (B.numel() == 0) { + bias.zero_(); + } else { + bias.copy_(B.sum(0)); + } + } + if (pre_gelu_out.numel() != 0) pre_gelu_out.zero_(); return; } @@ -109,10 +115,16 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int return tensor_wrappers.back().data(); }; for (size_t i = 0; i < A.size(); i++) { - if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) { - if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_(); - if (bias[i].data_ptr() != nullptr) bias[i].zero_(); - if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_(); + if (A[i].numel() == 0 || B[i].numel() == 0) { + if (D[i].numel() != 0 && !accumulate) D[i].zero_(); + if (bias[i].numel() != 0 && grad) { + if (B[i].numel() == 0) { + bias[i].zero_(); + } else { + bias[i].copy_(B[i].sum(0)); + } + } + if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_(); continue; } @@ -175,6 +187,8 @@ void te_grouped_gemm_single_output( void* d_i_ptr = reinterpret_cast(D.data_ptr()); for (size_t i = 0; i < A.size(); i++) { if (m_splits[i] == 0) continue; + NVTE_CHECK(A[i].data_ptr() != nullptr, "A[", i, "] must not be nullptr."); + NVTE_CHECK(B[i].data_ptr() != nullptr, "B[", i, "] must not be nullptr."); NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); te_A.emplace_back(make_tensor(