fix inf in fused_attention (#41933) (#42032)
wangxicoding authored Apr 21, 2022
1 parent efaef31 commit 50fd245
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions paddle/fluid/operators/fused/fmha_ref.h
@@ -16,6 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/transpose_op.cu.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/funcs/functors.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace paddle {
@@ -117,6 +119,18 @@ class FMHARef {
       v_ptr = k_ptr + k_size;
     }
 
+    {
+      // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for
+      // float16 calculation, INF may appear in QK^T if we do not scale before.
+      float alpha = 1.0 / sqrt(head_dim_);
+      auto q_tensor = transpose_2_out_tensor->Slice(0, 1);
+      auto functor = phi::funcs::ScaleFunctor<T>(alpha);
+      std::vector<const framework::Tensor*> ins = {&q_tensor};
+      std::vector<framework::Tensor*> outs = {&q_tensor};
+      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx_, ins,
+                                                                &outs, functor);
+    }
+
     // q*k^t, batched_gemm
     CBLAS_TRANSPOSE transA = CblasNoTrans;
     CBLAS_TRANSPOSE transB = CblasTrans;
@@ -125,7 +139,7 @@ class FMHARef {
     int gemm_m = seq_len_;
     int gemm_n = out_seq_len;
     int gemm_k = head_dim_;
-    T alpha = static_cast<T>(1.0 / sqrt(head_dim_));
+    T alpha = static_cast<T>(1.0);
     T beta = static_cast<T>(0.0);
     int64_t stride_a = gemm_m * gemm_k;
     int64_t stride_b = gemm_k * gemm_n;
@@ -300,7 +314,9 @@ class FMHARef {
     }
 
     T* qk_out_grad_data = qk_out_grad_tensor->data<T>();
-    alpha = static_cast<T>(1.0 / sqrt(head_dim_));
+    // NOTE(wangxi): Since we scale Q with 1/sqrt(Dh) in the forward pass,
+    // we set alpha = 1.0 here in the backward pass.
+    alpha = static_cast<T>(1.0);
     // recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out
     // bw: dy (seq_len * head_dim) = (dout)^t * x
     transA = CblasTrans;
@@ -314,6 +330,7 @@ class FMHARef {
                      qk_out_grad_data, q_ptr, beta, k_grad_ptr, gemm_batch_size,
                      stride_a, stride_b);
     // dx (seq_len * head_dim) = dout * y
+    alpha = static_cast<T>(1.0 / sqrt(head_dim_));
     transA = CblasNoTrans;
     transB = CblasNoTrans;
     gemm_m = seq_len_;
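Why pre-scaling helps: float16 can only represent finite values up to about 65504, so forming the raw QK^T products in half precision can overflow to INF before the GEMM's alpha = 1/sqrt(Dh) is ever applied. The standalone sketch below (not part of the commit; the head dimension and element magnitudes are hypothetical) illustrates the effect with plain float arithmetic.

// Standalone illustration (not from the Paddle source): compares the raw
// QK^T accumulation against the pre-scaled one for a single query/key pair.
#include <cmath>
#include <cstdio>

int main() {
  const int head_dim = 128;         // hypothetical Dh
  const float q_val = 30.0f;        // hypothetical per-element magnitudes
  const float k_val = 30.0f;
  const float fp16_max = 65504.0f;  // largest finite float16 value

  // Old order: form Q*K^T first, scale by alpha = 1/sqrt(Dh) afterwards.
  float raw_dot = head_dim * q_val * k_val;  // 115200 > 65504 -> INF in fp16
  // New order: scale Q by 1/sqrt(Dh) first, then form the dot product.
  float scaled_dot =
      head_dim * (q_val / std::sqrt(static_cast<float>(head_dim))) * k_val;

  std::printf("raw Q.K        = %.1f (fp16 overflow: %s)\n", raw_dot,
              raw_dot > fp16_max ? "yes" : "no");
  std::printf("pre-scaled Q.K = %.1f (fp16 overflow: %s)\n", scaled_dot,
              scaled_dot > fp16_max ? "yes" : "no");
  return 0;
}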

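The backward-pass alpha values follow from the same rescaling: q_ptr now points at the already-scaled Q, so the dK GEMM uses alpha = 1.0, while the dQ GEMM must reapply the 1/sqrt(Dh) factor. A minimal standalone check (not from the Paddle source; scalars stand in for the Q, K and dS tensors) of that bookkeeping:

// Standalone check (not from the Paddle source): verifies that the new alpha
// choices reproduce the gradients of the original S = (1/sqrt(Dh)) * Q * K.
#include <cassert>
#include <cmath>

int main() {
  const float Dh = 64.0f;                     // hypothetical head_dim
  const float scale = 1.0f / std::sqrt(Dh);
  const float Q = 3.0f, K = 5.0f, dS = 0.7f;  // hypothetical scalar stand-ins

  // Original formulation: S = scale * Q * K
  //   => dK = scale * dS * Q,  dQ = scale * dS * K.
  const float dK_ref = scale * dS * Q;
  const float dQ_ref = scale * dS * K;

  // New formulation: Q_scaled = scale * Q is what lives in memory,
  // and S = Q_scaled * K.
  const float Q_scaled = scale * Q;
  const float dK_new = 1.0f * dS * Q_scaled;  // dK GEMM: alpha = 1.0
  const float dQ_new = scale * dS * K;        // dQ GEMM: alpha = 1/sqrt(Dh)

  assert(std::fabs(dK_new - dK_ref) < 1e-6f);
  assert(std::fabs(dQ_new - dQ_ref) < 1e-6f);
  return 0;
}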