diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 613cd4989276d..ad52f4513b7b5 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1420,6 +1420,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // result of HasAttr. if (!enable_cache_runtime_context_ && HasAttr(kEnableCacheRuntimeContext)) enable_cache_runtime_context_ = true; + if (this->Type() == "fused_multi_transformer_int8" || this->Type() == "fused_multi_transformer_moe_int8") + enable_cache_runtime_context_ = true; if (!all_kernels_must_compute_runtime_shape_ && HasAttr(kAllKernelsMustComputeRuntimeShape)) all_kernels_must_compute_runtime_shape_ = true; diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index fe82565bc36f3..05c52b850db14 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -24,6 +24,8 @@ register_operators( fused_feedforward_op fused_multi_transformer_op fused_multi_transformer_int8_op + fused_multi_transformer_moe_op + fused_multi_transformer_moe_int8_op fused_bias_dropout_residual_layer_norm_op resnet_unit_op fused_gemm_epilogue_op @@ -121,6 +123,8 @@ if(WITH_GPU OR WITH_ROCM) op_library(fused_attention_op) op_library(fused_multi_transformer_op) op_library(fused_multi_transformer_int8_op) + op_library(fused_multi_transformer_moe_op) + op_library(fused_multi_transformer_moe_int8_op) op_library(fused_bias_dropout_residual_layer_norm_op) endif() # resnet_unit needs cudnn 8.0 above diff --git a/paddle/fluid/operators/fused/attn_gemm_int8.h b/paddle/fluid/operators/fused/attn_gemm_int8.h index ba114df9085fb..ce392e98ba606 100644 --- a/paddle/fluid/operators/fused/attn_gemm_int8.h +++ b/paddle/fluid/operators/fused/attn_gemm_int8.h @@ -20,13 +20,14 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/quant_dequant_kernel.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; +using phi::backends::gpu::GpuLaunchConfig; template class AttnMatmulINT8 { @@ -34,23 +35,26 @@ class AttnMatmulINT8 { AttnMatmulINT8( const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias) : dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) { - auto helper = std::make_shared(m, k, n); - helpers_.emplace_back(helper); + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + helper_ = std::make_unique>(m, k, n, lt_handle); + gpu_config_ = std::make_unique( + phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, m * n, DequantKernelVecSize)); } ~AttnMatmulINT8() {} // This function is used to execute GEMM, with input and output's types are // both T. - void ComputeForward(const framework::Tensor* weight, - const framework::Tensor* input, - framework::Tensor* input_tmp, - const framework::Tensor* bias, - framework::Tensor* output, - framework::Tensor* output_tmp, - framework::Tensor* bias_out, + void ComputeForward(const phi::DenseTensor* weight, + const phi::DenseTensor* input, + phi::DenseTensor* input_tmp, + const phi::DenseTensor* bias, + phi::DenseTensor* output, + phi::DenseTensor* output_tmp, + phi::DenseTensor* bias_out, const float quant_in_scale, - const framework::Tensor* dequant_out_scale, - const int quant_out_scale_offset, + const phi::DenseTensor* dequant_out_scale, + phi::DenseTensor* workspace = nullptr, const int quant_round_type = 1, const float quant_max_bound = 127.0, const float quant_min_bound = -127.0) { @@ -64,24 +68,26 @@ class AttnMatmulINT8 { quant_min_bound, dev_ctx_.stream()); - helpers_[0]->GEMM(input_tmp->data(), - weight->data(), - output_tmp->data(), - dev_ctx_.stream()); + helper_->GEMM(input_tmp->data(), + weight->data(), + output_tmp->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); dequantize_kernel_launcher(output_tmp->data(), output->data(), m_, n_, dev_ctx_.stream(), + gpu_config_.get(), quant_in_scale, - dequant_out_scale->data(), - quant_out_scale_offset); + dequant_out_scale->data()); if (compute_bias_) { // bias_out = output + bias - std::vector ins = {output, bias}; - std::vector outs = {bias_out}; + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; phi::funcs::BroadcastKernel( dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); PADDLE_ENFORCE_EQ(cudaGetLastError(), @@ -95,66 +101,72 @@ class AttnMatmulINT8 { // This function is used to execute GEMM, with input and output's types are // both INT8. - void ComputeForwardINT8ToINT8(const framework::Tensor* weight, - framework::Tensor* input, - const framework::Tensor* bias, - framework::Tensor* output, - framework::Tensor* bias_out) { - helpers_[0]->GEMM(input->data(), - weight->data(), - output->data(), - dev_ctx_.stream()); + void ComputeForwardINT8ToINT8(const phi::DenseTensor* weight, + phi::DenseTensor* input, + const phi::DenseTensor* bias, + phi::DenseTensor* output, + phi::DenseTensor* bias_out, + phi::DenseTensor* workspace = nullptr) { + helper_->GEMM(input->data(), + weight->data(), + output->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); } // This function is used to execute GEMM, with input and output's types are // INT8 and T. - void ComputeForwardINT8ToT(const framework::Tensor* weight, + void ComputeForwardINT8ToT(const phi::DenseTensor* weight, const float quant_in_scale, - framework::Tensor* input, - const framework::Tensor* bias, - framework::Tensor* output, - framework::Tensor* output_tmp, - framework::Tensor* bias_out, - const framework::Tensor* dequant_out_scale, - const int quant_out_scale_offset) { - helpers_[0]->GEMM(input->data(), - weight->data(), - output_tmp->data(), - dev_ctx_.stream()); + phi::DenseTensor* input, + const phi::DenseTensor* bias, + phi::DenseTensor* output, + phi::DenseTensor* output_tmp, + phi::DenseTensor* bias_out, + const phi::DenseTensor* dequant_out_scale, + phi::DenseTensor* workspace = nullptr) { + helper_->GEMM(input->data(), + weight->data(), + output_tmp->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); dequantize_kernel_launcher(output_tmp->data(), output->data(), m_, n_, dev_ctx_.stream(), + gpu_config_.get(), quant_in_scale, - dequant_out_scale->data(), - quant_out_scale_offset); + dequant_out_scale->data()); if (compute_bias_) { // bias_out = output + bias - std::vector ins = {output, bias}; - std::vector outs = {bias_out}; + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; phi::funcs::BroadcastKernel( dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); - PADDLE_ENFORCE_EQ(cudaGetLastError(), - cudaSuccess, - platform::errors::Fatal( - "cuda error occured after computing bias. " - "But it does not mean this error is caused by " - "bias computing")); + // PADDLE_ENFORCE_EQ(cudaGetLastError(), + // cudaSuccess, + // platform::errors::Fatal( + // "cuda error occured after computing bias. " + // "But it does not mean this error is caused by " + // "bias computing")); } } // This function is used to execute GEMM, with input and output's types are T // and INT8. - void ComputeForwardTToINT8(const framework::Tensor* weight, + void ComputeForwardTToINT8(const phi::DenseTensor* weight, const float quant_in_scale, - const framework::Tensor* input, - framework::Tensor* input_tmp, - const framework::Tensor* bias, - framework::Tensor* output, - framework::Tensor* bias_out, + const phi::DenseTensor* input, + phi::DenseTensor* input_tmp, + const phi::DenseTensor* bias, + phi::DenseTensor* output, + phi::DenseTensor* bias_out, + phi::DenseTensor* workspace = nullptr, const int quant_round_type = 1, const float quant_max_bound = 127.0, const float quant_min_bound = -127.0) { @@ -168,10 +180,12 @@ class AttnMatmulINT8 { quant_min_bound, dev_ctx_.stream()); - helpers_[0]->GEMM(input_tmp->data(), - weight->data(), - output->data(), - dev_ctx_.stream()); + helper_->GEMM(input_tmp->data(), + weight->data(), + output->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); } private: @@ -182,8 +196,9 @@ class AttnMatmulINT8 { int k_; // k int compute_bias_; - std::vector> helpers_; + std::unique_ptr> helper_; + std::unique_ptr gpu_config_; }; } // namespace operators -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/fused/cublaslt.h b/paddle/fluid/operators/fused/cublaslt.h index b9cc6b56f13ee..b889d3a4d219d 100644 --- a/paddle/fluid/operators/fused/cublaslt.h +++ b/paddle/fluid/operators/fused/cublaslt.h @@ -1,4 +1,5 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,175 +15,796 @@ limitations under the License. */ #pragma once +#include #include #include #include #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/dynload/cublasLt.h" +DECLARE_int64(cublaslt_exhaustive_search_times); + namespace dyl = paddle::platform::dynload; namespace paddle { namespace operators { + +#define PADDLE_CUBLASLT_STATUS_CHECK(name) \ + PADDLE_ENFORCE_EQ( \ + status, \ + CUBLAS_STATUS_SUCCESS, \ + platform::errors::External( \ + #name \ + "execution error" \ + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " \ + "information")) + +const int split_k_candidates[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; + +struct CublasLtAlgoSelectorParam { + cublasLtMatmulAlgo_t algo; + int m; + int n; + int k; + int algo_id; + int swizzle; + int custom_option; + int tile; + int split_k_val; + int reduction_scheme; + int stages; + void* workspace; + size_t workspace_size; + float time; +}; + +inline bool compare_algo_time(const CublasLtAlgoSelectorParam& param_a, + const CublasLtAlgoSelectorParam& param_b) { + return (param_a.time < param_b.time); +} +#if CUDA_VERSION >= 11020 +class CublasLtAlgoCache { + public: + static CublasLtAlgoCache& Instance() { + static CublasLtAlgoCache instance(FLAGS_cublaslt_exhaustive_search_times); + return instance; + } + + template + void TestMatmulRun(cublasLtHandle_t handle, + cublasLtMatmulDesc_t matmul_desc, + cublasLtMatrixLayout_t a_desc, + cublasLtMatrixLayout_t b_desc, + cublasLtMatrixLayout_t c_desc, + void* alpha, + void* beta, + const InT* a, + const InT* b, + OutT* c, + CublasLtAlgoSelectorParam& param, // NOLINT + cudaEvent_t& start_event, // NOLINT + cudaEvent_t& stop_event, // NOLINT + cudaStream_t stream) { + cublasStatus_t status; + cublasLtMatmulHeuristicResult_t heuristic_result; + status = dyl::cublasLtMatmulAlgoCheck(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + c_desc, + ¶m.algo, + &heuristic_result); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCheck); + if (status != CUBLAS_STATUS_SUCCESS || + heuristic_result.workspaceSize > param.workspace_size) { + // VLOG(0) << "param.workspace_size is " << param.workspace_size; + param.time = std::numeric_limits::max(); + return; + } + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream)); + int repeats = search_times_; + + for (int loop = 0; loop < repeats; loop++) { + status = dyl::cublasLtMatmul(handle, + matmul_desc, + alpha, + a, + a_desc, + b, + b_desc, + beta, + c, + c_desc, + c, + c_desc, + ¶m.algo, + param.workspace, + param.workspace_size, + stream); + if (status != CUBLAS_STATUS_SUCCESS) { + param.time = std::numeric_limits::max(); + return; + } + } + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + float time; + PADDLE_ENFORCE_GPU_SUCCESS( + cudaEventElapsedTime(&time, start_event, stop_event)); + + param.time = time / repeats; + } + + template + cublasLtMatmulAlgo_t* CublasLtAlgoSelect(cublasLtHandle_t handle, + int m, + int n, + int k, + const InT* a, + const InT* b, + OutT* c, + void* alpha, + void* beta, + cublasLtMatmulDesc_t matmul_desc, + cublasLtMatrixLayout_t a_desc, + cublasLtMatrixLayout_t b_desc, + cublasLtMatrixLayout_t c_desc, + cublasComputeType_t compute_type, + cudaDataType_t scale_type, + cudaDataType_t a_type, + cudaDataType_t b_type, + cudaDataType_t c_type, + void* workspace, + size_t workspace_size, + cudaStream_t stream) { + if (search_times_ <= 0) { + VLOG(3) << "Skip CublasLtAlgoSelect process, use default algo instead. " + "If you want to enable CublasLtAlgoSelect, " + "please set FLAGS_cublaslt_exhaustive_search_times > 0"; + return nullptr; + } + + VLOG(1) << "m n k " << m << " " << n << " " << k; + + int64_t seed = 0; + std::hash hash_fn; + + HashMatmulDesc_(matmul_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(a_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(b_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(c_desc, &seed, hash_fn); + + cublasLtMatmulAlgo_t ret; + { + std::lock_guard lock(cache_mutex_); + auto it = map_.find(seed); + if (it != map_.end()) { + VLOG(3) << "CublasLtAlgoSelect Found in cache"; + return &(it->second); + } + } + VLOG(3) << "CublasLtAlgoSelect Not Found in cache"; + + // Get Ids + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoGetIds + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + // std::vector algo_ids(requested_algo_count_); + int algo_ids[requested_algo_count_]; // NOLINT + + int num_algo_ids; + status = dyl::cublasLtMatmulAlgoGetIds(handle, + compute_type, + scale_type, + a_type, + b_type, + c_type, + c_type, + requested_algo_count_, + algo_ids, + &num_algo_ids); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoGetIds); + + // Traverse all posssible algo combinations + int step = 0; + int limit = 20000; + std::vector params; + + for (int idx = 0; idx < num_algo_ids; idx++) { + cublasLtMatmulAlgo_t algo; + + /* Initialize algo structure with given Algp ID */ + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoInit + status = dyl::cublasLtMatmulAlgoInit(handle, + compute_type, + scale_type, + a_type, + b_type, + c_type, + c_type, + algo_ids[idx], + &algo); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoInit); + + // Query the tiles enums supported by that algo which is used to alloc + // enough space to store it + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCapGetAttribute + size_t attr_size = 0; + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, CUBLASLT_ALGO_CAP_TILE_IDS, nullptr, 0, &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + + int num_tiles = static_cast(attr_size / sizeof(int)); + std::vector tiles(num_tiles == 0 ? 1 : num_tiles); + if (num_tiles == 0) { + tiles[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; + num_tiles = 1; + } else { + status = + dyl::cublasLtMatmulAlgoCapGetAttribute(&algo, + CUBLASLT_ALGO_CAP_TILE_IDS, + tiles.data(), + sizeof(int) * num_tiles, + &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + } + + // Query the stages enums supported by that algo (cuda must >= 11.0) + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, nullptr, 0, &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + int num_stages = static_cast(attr_size / sizeof(int)); + std::vector stages(num_stages == 0 ? 1 : num_stages); + if (num_stages == 0) { + stages[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; + num_stages = 1; + } else { + status = + dyl::cublasLtMatmulAlgoCapGetAttribute(&algo, + CUBLASLT_ALGO_CAP_STAGES_IDS, + stages.data(), + sizeof(int) * num_stages, + &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + } + + // Retrieve Other Algo Capabilities attributes + int splitk_support, red_mask, swizzling_max, custom_option_max; + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, + &splitk_support, + sizeof(splitk_support), + &attr_size); + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, + &red_mask, + sizeof(red_mask), + &attr_size); + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, + &swizzling_max, + sizeof(swizzling_max), + &attr_size); + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, + &custom_option_max, + sizeof(custom_option_max), + &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + + /* Loop over the different tiles */ + for (int tile_id = 0; tile_id < num_tiles && step < limit; tile_id++) { + /* Loop over different stages count */ + for (int stage_id = 0; stage_id < num_stages && step < limit; + stage_id++) { + /* Loop over the different custom option if any */ + for (int custom_option = 0; + custom_option <= custom_option_max && step < limit; + custom_option++) { + /* Loop over the CTAs swizzling support */ + for (int k = 0; k <= swizzling_max && step < limit; k++) { + int splir_k_trial = 0; + if (splitk_support) { + splir_k_trial += + sizeof(split_k_candidates) / sizeof(split_k_candidates[0]); + } + + for (int l = 0; (l < (1 + splir_k_trial)) && (step < limit); + l++) { + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_TILE_ID, + &tiles[tile_id], + sizeof(tiles[tile_id])); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_STAGES_ID, + &stages[stage_id], + sizeof(stages[stage_id])); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &custom_option, + sizeof(custom_option)); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k)); + int split_k_val = 0; + int reduction_scheme = CUBLASLT_REDUCTION_SCHEME_NONE; + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &split_k_val, + sizeof(split_k_val)); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &reduction_scheme, + sizeof(int)); + if (l > 0) { // Split-K case + split_k_val = split_k_candidates[l - 1]; + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &split_k_candidates[l - 1], + sizeof(split_k_candidates[l - 1])); + for (reduction_scheme = 1; + reduction_scheme < + static_cast(CUBLASLT_REDUCTION_SCHEME_MASK) && + (step < limit); + reduction_scheme = reduction_scheme << 1) { + if (reduction_scheme & red_mask) { + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &reduction_scheme, + sizeof(reduction_scheme)); + PADDLE_CUBLASLT_STATUS_CHECK( + cublasLtMatmulAlgoConfigSetAttribute); + + cublasLtMatmulHeuristicResult_t heurResult; + status = dyl::cublasLtMatmulAlgoCheck(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + c_desc, + &algo, + &heurResult); + if (status == CUBLAS_STATUS_SUCCESS) { + CublasLtAlgoSelectorParam algo_select_params; + algo_select_params.algo = algo; + algo_select_params.m = m; + algo_select_params.n = n; + algo_select_params.k = k; + algo_select_params.algo_id = algo_ids[idx]; + algo_select_params.tile = tiles[tile_id]; + algo_select_params.swizzle = k; + algo_select_params.custom_option = custom_option; + algo_select_params.split_k_val = split_k_val; + algo_select_params.reduction_scheme = reduction_scheme; + algo_select_params.stages = stages[stage_id]; + algo_select_params.workspace_size = workspace_size; + algo_select_params.workspace = workspace; + params.emplace_back(algo_select_params); + step++; + } + } // end if + } + } else { + // Prepare algos + cublasLtMatmulHeuristicResult_t heurResult; + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCheck + status = dyl::cublasLtMatmulAlgoCheck(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + c_desc, + &algo, + &heurResult); + if (status == CUBLAS_STATUS_SUCCESS) { + CublasLtAlgoSelectorParam algo_select_params; + algo_select_params.algo = algo; + algo_select_params.m = m; + algo_select_params.n = n; + algo_select_params.k = k; + algo_select_params.algo_id = algo_ids[idx]; + algo_select_params.tile = tiles[tile_id]; + algo_select_params.swizzle = k; + algo_select_params.custom_option = custom_option; + algo_select_params.split_k_val = split_k_val; + algo_select_params.reduction_scheme = reduction_scheme; + algo_select_params.stages = stages[stage_id]; + algo_select_params.workspace_size = workspace_size; + algo_select_params.workspace = workspace; + params.emplace_back(algo_select_params); + step++; + } + } + } + } + } + } + } + } + cudaEvent_t start_event; + cudaEvent_t stop_event; + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event)); + + if (step == 0) { + VLOG(3) << "No algo can be used"; + return nullptr; + } + + VLOG(3) << "CublasLtAlgoSelect Start testRun " << step << " " + << params.size(); + + for (int i = 0; i < step; i++) { + TestMatmulRun(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + alpha, + beta, + a, + b, + c, + params[i], + start_event, + stop_event, + stream); + } + std::sort(params.begin(), params.end(), compare_algo_time); + + int res_id = 0; + while (params[res_id].time == 0) res_id++; + + if (res_id >= params.size()) { + VLOG(3) << "No algo can be used"; + return nullptr; + } + + VLOG(3) << "algo selected"; + + ret = params[res_id].algo; + std::lock_guard lock(cache_mutex_); + auto& algo_in_map = map_[seed]; + algo_in_map = ret; + return &algo_in_map; + } + + ~CublasLtAlgoCache() { + // Serialize map_ to cache file + std::ofstream outfile; + outfile.open(config_filename_, std::ios::out | std::ios::trunc); + outfile << dyl::cublasLtGetCudartVersion() << std::endl; + + for (const auto p : map_) { + outfile << p.first << " "; + for (int i = 0; i < 8; ++i) { + outfile << p.second.data[i] << " "; + } + outfile << std::endl; + } + outfile.close(); + } + + private: + explicit CublasLtAlgoCache(int search_times) : search_times_(search_times) { + // Init map_ from cache file + std::ifstream infile; + infile.open(config_filename_); + if (!infile.is_open()) { + VLOG(3) << "No CublasLtAlgoCache file found"; + return; + } + size_t cublaslt_version, real_cublaslt_version; + int64_t seed = 0; + uint64_t algo_data[8]; + infile >> cublaslt_version; + VLOG(1) << "cublaslt_version " << cublaslt_version; + + if (dyl::cublasLtGetCudartVersion() != cublaslt_version) { + LOG(INFO) << config_filename_ + << " is not compatible with current cublaslt_version " + << real_cublaslt_version; + return; + } + + while (!infile.eof()) { + infile >> seed >> algo_data[0] >> algo_data[1] >> algo_data[2] >> + algo_data[3] >> algo_data[4] >> algo_data[5] >> algo_data[6] >> + algo_data[7]; + + for (int i = 0; i < 8; ++i) { + map_[seed].data[i] = algo_data[i]; + } + } + infile.close(); + } + + std::string config_filename_{"/tmp/paddle_cublaslt_cache"}; + std::unordered_map map_; + int search_times_; + const int requested_algo_count_ = 100; + std::mutex cache_mutex_; + + void HashMatmulDesc_(cublasLtMatmulDesc_t desc, + int64_t* seed, + const std::hash& hash_fn) { + size_t size_to_write; + int trans_a, trans_b; + uint32_t epilogue; + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatmulDescGetAttribute(desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &trans_a, + sizeof(trans_a), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(trans_a)); + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatmulDescGetAttribute(desc, + CUBLASLT_MATMUL_DESC_TRANSB, + &trans_b, + sizeof(trans_b), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(trans_b)); + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatmulDescGetAttribute(desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, + sizeof(epilogue), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(epilogue)); + } + + void HashMatrixLayoutDesc_(cublasLtMatrixLayout_t desc, + int64_t* seed, + const std::hash& hash_fn) { + size_t size_to_write; + uint32_t dtype; + int32_t batch; + uint64_t row, col; + int64_t ld, batch_offset; + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatrixLayoutGetAttribute(desc, + CUBLASLT_MATRIX_LAYOUT_TYPE, + &dtype, + sizeof(dtype), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(dtype)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch, + sizeof(batch), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(batch)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), &size_to_write)); + HashValue_(seed, hash_fn, static_cast(row)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), &size_to_write)); + HashValue_(seed, hash_fn, static_cast(col)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); + HashValue_(seed, hash_fn, static_cast(ld)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_offset, + sizeof(batch_offset), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(batch_offset)); + } + + void HashValue_(int64_t* seed, + const std::hash& hash_fn, + int64_t value) { + *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + } +}; +#endif + +template class CublasLtHelper { public: - CublasLtHelper(int m, int k, int n) - : alpha_(1), beta_(0), m_(m), k_(k), n_(n) { + CublasLtHelper(int m, int k, int n, cublasLtHandle_t handle) + : alpha_(1), beta_(0), m_(m), k_(k), n_(n), handle_(handle) { cublasStatus_t status; // handle and matmul desc - status = dyl::cublasLtCreate(&handle_); + // status = dyl::cublasLtCreate(&handle_); + // PADDLE_CUBLASLT_STATUS_CHECK(cublasLtCreate); + if (std::is_same::value) { + scale_type_ = CUDA_R_16F; + a_type_ = CUDA_R_16F; + b_type_ = CUDA_R_16F; + c_type_ = CUDA_R_16F; #if CUBLAS_VER_MAJOR < 11 - cudaDataType_t cudaComputeType = CUDA_R_32I; + compute_type_ = CUDA_R_16F; #else - cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; + compute_type_ = CUBLAS_COMPUTE_16F; #endif - - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + } else if (std::is_same::value) { + scale_type_ = CUDA_R_32F; + a_type_ = CUDA_R_32F; + b_type_ = CUDA_R_32F; + c_type_ = CUDA_R_32F; +#if CUBLAS_VER_MAJOR < 11 + compute_type_ = CUDA_R_32F; +#else + compute_type_ = CUBLAS_COMPUTE_32F; +#endif + } else if (std::is_same::value) { + scale_type_ = CUDA_R_32I; + a_type_ = CUDA_R_8I; + b_type_ = CUDA_R_8I; + c_type_ = CUDA_R_32I; +#if CUBLAS_VER_MAJOR < 11 + compute_type_ = CUDA_R_32I; +#else + compute_type_ = CUBLAS_COMPUTE_32I; +#endif + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "CublasLtHelper just implement for FP16/FP32/INT32.")); + } #if CUBLAS_VER_MAJOR < 11 - status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType); + status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, compute_type_); #else status = dyl::cublasLtMatmulDescCreate( - &matmul_desc_, cudaComputeType, CUDA_R_32I); + &matmul_desc_, compute_type_, scale_type_); #endif + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulDescCreate); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatmulDescCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + // Node: Just test for int8 cublasOperation_t op_transpose = CUBLAS_OP_T; status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_, CUBLASLT_MATMUL_DESC_TRANSA, &op_transpose, sizeof(op_transpose)); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatmulDescSetAttribute execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulDescSetAttribute); // matrix desc - status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + status = dyl::cublasLtMatrixLayoutCreate(&b_desc_, a_type_, k, n, k); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatrixLayoutCreate); - status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + status = dyl::cublasLtMatrixLayoutCreate(&a_desc_, b_type_, k, m, k); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatrixLayoutCreate); - status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + status = dyl::cublasLtMatrixLayoutCreate(&c_desc_, c_type_, n, m, n); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatrixLayoutCreate); } ~CublasLtHelper() { - if (handle_) dyl::cublasLtDestroy(handle_); - if (matmul_desc_) dyl::cublasLtMatmulDescDestroy(matmul_desc_); - if (A_desc_) dyl::cublasLtMatrixLayoutDestroy(A_desc_); - if (B_desc_) dyl::cublasLtMatrixLayoutDestroy(B_desc_); - if (C_desc_) dyl::cublasLtMatrixLayoutDestroy(C_desc_); + dyl::cublasLtMatmulDescDestroy(matmul_desc_); + dyl::cublasLtMatrixLayoutDestroy(a_desc_); + dyl::cublasLtMatrixLayoutDestroy(b_desc_); + dyl::cublasLtMatrixLayoutDestroy(c_desc_); } - void GEMM(int8_t* A_dev, - const int8_t* B_dev, - int32_t* C_dev, - cudaStream_t stream) { + template + void GEMM(const InT* a_dev, + const InT* b_dev, + OutT* c_dev, + cudaStream_t stream, + void* workspace = nullptr, + size_t workspace_size = 0) { cublasStatus_t status; -#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020 - cublasLtMatmulAlgo_t algo; - int algoId = 21; - int swizzle = 0; - int customOption = 0; - int tile = 15; - int splitK_val = 0; - int reductionScheme = 0; -#if CUDA_VERSION >= 11000 - int stages = 23; -#endif - -#if CUBLAS_VER_MAJOR < 11 - cudaDataType_t cudaComputeType = CUDA_R_32I; -#else - cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; -#endif - - dyl::cublasLtMatmulAlgoInit(handle_, - cudaComputeType, +#if CUDA_VERSION >= 11020 + cublasLtMatmulAlgo_t* algo = + CublasLtAlgoCache::Instance().CublasLtAlgoSelect(handle_, + m_, + n_, + k_, + b_dev, + a_dev, + c_dev, + &alpha_, + &beta_, + matmul_desc_, + b_desc_, + a_desc_, + c_desc_, + compute_type_, + scale_type_, + b_type_, + a_type_, + c_type_, + workspace, + workspace_size, + stream); + + cublasLtMatmulAlgo_t algo_; + if (algo == nullptr) { + int algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 15; + int splitK_val = 0; + int reductionScheme = 0; + int stages = 23; + if (m_ >= 128) { + tile = 20; + stages = 17; + } + dyl::cublasLtMatmulAlgoInit(handle_, + compute_type_, CUDA_R_32I, CUDA_R_8I, CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, - &algo); + &algo_); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, + &algo_, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(customOption), sizeof(customOption)); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); - dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo, + &algo_, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); + dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo_, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(splitK_val), sizeof(splitK_val)); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + &algo_, + CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, + &(swizzle), + sizeof(swizzle)); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, + &algo_, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int)); -#if CUDA_VERSION >= 11000 dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); -#endif + &algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); + algo = &algo_; + } #endif + status = dyl::cublasLtMatmul(handle_, matmul_desc_, &alpha_, - B_dev, - B_desc_, - A_dev, - A_desc_, + b_dev, + b_desc_, + a_dev, + a_desc_, &beta_, - C_dev, - C_desc_, - C_dev, - C_desc_, -#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020 - &algo, + c_dev, + c_desc_, + c_dev, + c_desc_, +#if CUDA_VERSION >= 11020 + algo, + workspace, + workspace_size, #else nullptr, -#endif nullptr, 0, +#endif stream); PADDLE_ENFORCE_EQ( status, @@ -196,11 +818,22 @@ class CublasLtHelper { private: cublasLtHandle_t handle_; cublasLtMatmulDesc_t matmul_desc_; - cublasLtMatrixLayout_t A_desc_; - cublasLtMatrixLayout_t B_desc_; - cublasLtMatrixLayout_t C_desc_; - int32_t alpha_; - int32_t beta_; + cublasLtMatrixLayout_t a_desc_; + cublasLtMatrixLayout_t b_desc_; + cublasLtMatrixLayout_t c_desc_; + + cudaDataType_t scale_type_; + cudaDataType_t a_type_; + cudaDataType_t b_type_; + cudaDataType_t c_type_; +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t compute_type_; +#else + cublasComputeType_t compute_type_; +#endif + + T alpha_; + T beta_; int m_; int k_; @@ -208,4 +841,4 @@ class CublasLtHelper { }; } // namespace operators -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 46153d980409d..be102980b4d78 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/operators/dropout_impl.cu.h" #include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h" #include "paddle/fluid/operators/transpose_op.cu.h" @@ -27,7 +28,22 @@ limitations under the License. */ namespace paddle { namespace operators { -using Tensor = framework::Tensor; +template +class PDTraits; + +template <> +class PDTraits { +public: + typedef float DataType; + typedef float data_t; +}; + +template <> +class PDTraits { +public: + typedef half DataType; + typedef paddle::float16 data_t; +}; class AttnDropoutParam { public: @@ -46,7 +62,7 @@ class AttnDropoutParam { bool is_upscale_in_train, bool is_fix_seed, int seed_val, - const Tensor* seed) { + const phi::DenseTensor* seed) { is_test_ = is_test; dropout_implementation_ = dropout_implementation; dropout_prob_ = dropout_prob; @@ -61,9 +77,81 @@ class AttnDropoutParam { bool is_upscale_in_train_; bool is_fix_seed_; int seed_val_; - const Tensor* seed_; + const phi::DenseTensor* seed_; }; +template +__global__ void TransposeRemovingPadding(const T* input_data, + T* output_data, + const int batch_size, + const int num_head, + const int seq_len, + const int head_dim, + const int token_num, + const int elem_cnt, + const int* padding_offset) { + // transpose and remove padding + // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, + // head_dim] + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + const int dim_embed = num_head * head_dim; + using LoadT = phi::AlignedVector; + LoadT src_vec; + + for (int32_t linear_index = idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / dim_embed; + const int ori_token_idx = + token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]); + const int ori_batch_id = ori_token_idx / seq_len; + const int ori_seq_id = ori_token_idx % seq_len; + const int ori_head_id = (linear_index % dim_embed) / head_dim; + const int ori_head_lane = (linear_index % dim_embed) % head_dim; + const int ori_idx = ori_batch_id * num_head * seq_len * head_dim + + ori_head_id * seq_len * head_dim + + ori_seq_id * head_dim + ori_head_lane; + phi::Load(&input_data[ori_idx], &src_vec); + phi::Store(src_vec, &output_data[linear_index]); + } +} + +template +void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx, + const T* input_data, + T* output_data, + const int batch_size, + const int num_head, + const int seq_len, + const int head_dim, + const int token_num, + const int* padding_offset) { + // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, + // head_dim] + constexpr int VEC_16B = 16; + const int elem_cnt = token_num * num_head * head_dim; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ( + head_dim % PackSize, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", head_dim, PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t block_size = 128; + int32_t grid_size = (pack_num + block_size - 1) / block_size; + TransposeRemovingPadding + <<>>(input_data, + output_data, + batch_size, + num_head, + seq_len, + head_dim, + token_num, + elem_cnt, + padding_offset); +} + template class FMHARef { public: @@ -82,18 +170,18 @@ class FMHARef { ~FMHARef() {} - void ComputeForward(const Tensor& qkv_input_tensor, - const Tensor* cache_kv_tensor, - const Tensor* src_mask_tensor, - Tensor* transpose_2_out_tensor, - Tensor* cache_kv_out_tensor, - Tensor* qk_out_tensor, - Tensor* src_mask_out_tensor, - Tensor* softmax_out_tensor, - Tensor* dropout_mask_out_tensor, - Tensor* dropout_out_tensor, - Tensor* qktv_out_tensor, - Tensor* fmha_out_tensor) { + void ComputeForward(const phi::DenseTensor& qkv_input_tensor, + const phi::DenseTensor* cache_kv_tensor, + const phi::DenseTensor* src_mask_tensor, + phi::DenseTensor* transpose_2_out_tensor, + phi::DenseTensor* cache_kv_out_tensor, + phi::DenseTensor* qk_out_tensor, + phi::DenseTensor* src_mask_out_tensor, + phi::DenseTensor* softmax_out_tensor, + phi::DenseTensor* dropout_mask_out_tensor, + phi::DenseTensor* dropout_out_tensor, + phi::DenseTensor* qktv_out_tensor, + phi::DenseTensor* fmha_out_tensor) { // input shape: [bs, seq_len, 3, num_head, head_dim] // transpose with perm [2, 0, 3, 1, 4], // output_shape: [3, bs, num_head, seq_len, head_dim] @@ -104,7 +192,6 @@ class FMHARef { T* qk_out_data = qk_out_tensor->data(); T* qktv_out_data = qktv_out_tensor->data(); T* softmax_out_data = softmax_out_tensor->data(); - T* dropout_out_data = dropout_out_tensor->data(); T* fmha_out_data = fmha_out_tensor->data(); auto out_seq_len = seq_len_; @@ -142,8 +229,8 @@ class FMHARef { float alpha = 1.0 / sqrt(head_dim_); auto q_tensor = transpose_2_out_tensor->Slice(0, 1); auto functor = phi::funcs::ScaleFunctor(alpha); - std::vector ins = {&q_tensor}; - std::vector outs = {&q_tensor}; + std::vector ins = {&q_tensor}; + std::vector outs = {&q_tensor}; phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); } @@ -183,8 +270,8 @@ class FMHARef { seq_len_, dev_ctx_.stream()); } else { - std::vector ins; - std::vector outs; + std::vector ins; + std::vector outs; ins.emplace_back(qk_out_tensor); ins.emplace_back(src_mask_tensor); outs.emplace_back(src_mask_out_tensor); @@ -220,11 +307,12 @@ class FMHARef { dropout_param_.is_upscale_in_train_, dropout_param_.is_fix_seed_, dropout_param_.seed_val_, - static_cast(*softmax_out_tensor), + static_cast(*softmax_out_tensor), dropout_param_.seed_, dropout_mask_out_tensor, dropout_out_tensor, false); + T* dropout_out_data = dropout_out_tensor->data(); blas.BatchedGEMM(transA, transB, gemm_m, @@ -262,22 +350,210 @@ class FMHARef { dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); } - void ComputeBackward(const Tensor& transpose_2_out_tensor, - const Tensor* src_mask_tensor, - const Tensor& softmax_out_tensor, - const Tensor& dropout_mask_out_tensor, - const Tensor& dropout_out_tensor, - const Tensor& qk_out_tensor, - const Tensor& src_mask_out_tensor, - const Tensor& fmha_out_grad_tensor, - Tensor* qktv_out_grad_tensor, - Tensor* dropout_out_grad_tensor, - Tensor* softmax_out_grad_tensor, - Tensor* src_mask_out_grad_tensor, - Tensor* qk_out_grad_tensor, - Tensor* transpose_2_out_grad_tensor, - Tensor* src_mask_grad_tensor, - Tensor* qkv_input_grad_tensor) { + void ComputeForwardWithoutTranspose( + const phi::DenseTensor* cache_kv_tensor, + const phi::DenseTensor* src_mask_tensor, + const phi::DenseTensor* padding_offset_tensor, + phi::DenseTensor* q_transpose_out_tensor, + phi::DenseTensor* kv_transpose_out_tensor, + phi::DenseTensor* cache_kv_out_tensor, + phi::DenseTensor* qk_out_tensor, + phi::DenseTensor* src_mask_out_tensor, + phi::DenseTensor* softmax_out_tensor, + phi::DenseTensor* dropout_mask_out_tensor, + phi::DenseTensor* dropout_out_tensor, + phi::DenseTensor* qktv_out_tensor, + phi::DenseTensor* fmha_out_tensor, + const int token_num) { + // input shape: [bs, seq_len, 3, num_head, head_dim] + // transpose with perm [2, 0, 3, 1, 4], + // output_shape: [3, bs, num_head, seq_len, head_dim] + T* qk_out_data = qk_out_tensor->data(); + T* qktv_out_data = qktv_out_tensor->data(); + T* softmax_out_data = softmax_out_tensor->data(); + T* fmha_out_data = fmha_out_tensor->data(); + + auto out_seq_len = seq_len_; + if (cache_kv_tensor) { + // kv [2, bs, num_head, seq_len, head_dim] + phi::funcs::ConcatFunctor concat; + // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] + concat(dev_ctx_, + {*cache_kv_tensor, *kv_transpose_out_tensor}, + 3, + cache_kv_out_tensor); + out_seq_len = cache_kv_out_tensor->dims()[3]; + } + + int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; + T* q_ptr = q_transpose_out_tensor->data(); + T* k_ptr = nullptr; + T* v_ptr = nullptr; + + if (cache_kv_tensor) { + int64_t k_size = cache_kv_out_tensor->numel() / 2; + k_ptr = cache_kv_out_tensor->data(); + v_ptr = k_ptr + k_size; + } else { + int64_t k_size = q_size; + k_ptr = kv_transpose_out_tensor->data(); + v_ptr = k_ptr + k_size; + } + + { + // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for + // float16 calculation, INF may appear in QK^T if we do not scale before. + float alpha = 1.0 / sqrt(head_dim_); + auto functor = phi::funcs::ScaleFunctor(alpha); + std::vector ins = {q_transpose_out_tensor}; + std::vector outs = {q_transpose_out_tensor}; + phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); + } + + // q*k^t, batched_gemm + CBLAS_TRANSPOSE transA = CblasNoTrans; + CBLAS_TRANSPOSE transB = CblasTrans; + auto blas = phi::funcs::GetBlas(dev_ctx_); + int gemm_batch_size = batch_size_ * num_head_; + int gemm_m = seq_len_; + int gemm_n = out_seq_len; + int gemm_k = head_dim_; + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + int64_t stride_a = gemm_m * gemm_k; + int64_t stride_b = gemm_k * gemm_n; + blas.BatchedGEMM(transA, + transB, + gemm_m, + gemm_n, + gemm_k, + alpha, + q_ptr, + k_ptr, + beta, + qk_out_data, + gemm_batch_size, + stride_a, + stride_b); + int softmax_axis = -1; + if (src_mask_tensor != nullptr) { + if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) { + LaunchFusedSoftmaxMaskKernel(qk_out_data, + src_mask_tensor->data(), + softmax_out_data, + batch_size_, + num_head_, + seq_len_, + dev_ctx_.stream()); + } else { + std::vector ins; + std::vector outs; + ins.emplace_back(qk_out_tensor); + ins.emplace_back(src_mask_tensor); + outs.emplace_back(src_mask_out_tensor); + int elewise_add_axis = -1; + phi::funcs::BroadcastKernel( + dev_ctx_, + ins, + &outs, + elewise_add_axis, + phi::funcs::AddFunctor()); + + phi::SoftmaxForwardCUDAKernelDriver( + dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); + } + } else { + phi::SoftmaxForwardCUDAKernelDriver( + dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor); + } + + transB = CblasNoTrans; + gemm_m = seq_len_; + gemm_n = head_dim_; + gemm_k = out_seq_len; + alpha = static_cast(1.0); + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + + if (dropout_param_.dropout_prob_) { + T* dropout_out_data = dropout_out_tensor->data(); + DropoutFwGPUKernelDriver( + static_cast(dev_ctx_), + dropout_param_.is_test_, + dropout_param_.dropout_prob_, + dropout_param_.is_upscale_in_train_, + dropout_param_.is_fix_seed_, + dropout_param_.seed_val_, + static_cast(*softmax_out_tensor), + dropout_param_.seed_, + dropout_mask_out_tensor, + dropout_out_tensor, + false); + blas.BatchedGEMM(transA, + transB, + gemm_m, + gemm_n, + gemm_k, + alpha, + dropout_out_data, + v_ptr, + beta, + qktv_out_data, + gemm_batch_size, + stride_a, + stride_b); + } else { + // softmax_out * v, batched_gemm + // output shape: [batch_size, num_heads, seq_len, head_dim] + blas.BatchedGEMM(transA, + transB, + gemm_m, + gemm_n, + gemm_k, + alpha, + softmax_out_data, + v_ptr, + beta, + qktv_out_data, + gemm_batch_size, + stride_a, + stride_b); + } + // transpose: [0, 2, 1, 3] + // output shape: [batch_size, seq_len, num_heads, head_dim] + if (!padding_offset_tensor) { + std::vector perm_3 = {0, 2, 1, 3}; + TransposeGPUKernelDriver( + dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); + } else { + InvokeTransposeRemovePadding(dev_ctx_, + qktv_out_data, + fmha_out_data, + batch_size_, + num_head_, + seq_len_, + head_dim_, + token_num, + padding_offset_tensor->data()); + } + } + + void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor, + const phi::DenseTensor* src_mask_tensor, + const phi::DenseTensor& softmax_out_tensor, + const phi::DenseTensor& dropout_mask_out_tensor, + const phi::DenseTensor& dropout_out_tensor, + const phi::DenseTensor& qk_out_tensor, + const phi::DenseTensor& src_mask_out_tensor, + const phi::DenseTensor& fmha_out_grad_tensor, + phi::DenseTensor* qktv_out_grad_tensor, + phi::DenseTensor* dropout_out_grad_tensor, + phi::DenseTensor* softmax_out_grad_tensor, + phi::DenseTensor* src_mask_out_grad_tensor, + phi::DenseTensor* qk_out_grad_tensor, + phi::DenseTensor* transpose_2_out_grad_tensor, + phi::DenseTensor* src_mask_grad_tensor, + phi::DenseTensor* qkv_input_grad_tensor) { auto blas = phi::funcs::GetBlas(dev_ctx_); int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; int k_size = q_size; @@ -294,8 +570,6 @@ class FMHARef { const T* softmax_out_data = softmax_out_tensor.data(); T* softmax_out_grad_data = softmax_out_grad_tensor->data(); - const T* dropout_out_data = dropout_out_tensor.data(); - T* dropout_out_grad_data = dropout_out_grad_tensor->data(); T* qktv_out_grad_data = qktv_out_grad_tensor->data(); // transpose bw @@ -317,6 +591,7 @@ class FMHARef { int64_t stride_b = gemm_k * gemm_n; // bw: dy = x^t * dout if (dropout_param_.dropout_prob_) { + const T* dropout_out_data = dropout_out_tensor.data(); blas.BatchedGEMM(transA, transB, gemm_m, @@ -354,6 +629,7 @@ class FMHARef { stride_a = gemm_m * gemm_k; stride_b = gemm_k * gemm_n; if (dropout_param_.dropout_prob_) { + T* dropout_out_grad_data = dropout_out_grad_tensor->data(); blas.BatchedGEMM(transA, transB, gemm_m, @@ -389,7 +665,7 @@ class FMHARef { false, dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_, - static_cast(*dropout_out_grad_tensor), + static_cast(*dropout_out_grad_tensor), dropout_mask_out_tensor, softmax_out_grad_tensor, false); @@ -495,3 +771,4 @@ class FMHARef { } // namespace operators } // namespace paddle + diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc index 9572a87aba21d..bd84667e21e0a 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc @@ -21,7 +21,7 @@ limitations under the License. */ namespace paddle { namespace operators { -using Tensor = framework::Tensor; +using Tensor = phi::DenseTensor; class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { private: @@ -58,6 +58,12 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { CHECK_INPUTS(FFN1Weight); CHECK_INPUTS(FFN2Weight); + // scale + CHECK_INPUTS(QKVOutScale); + CHECK_INPUTS(OutLinearOutScale); + CHECK_INPUTS(FFN1OutScale); + CHECK_INPUTS(FFN2OutScale); + CHECK_OUTPUT(Out); // x: qkv's input [batch_size, seq_len, dim_embed] @@ -93,26 +99,6 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { x_dim, y_dim)); - if (ctx->Attrs().Get("ring_id") == -1) { - if (trans_qkvw) { - PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], - y_dim[3], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(3, num_head, dim_head, dim_embed)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - - } else { - PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3], - y_dim[0], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(dim_embed, 3, num_head, dim_head)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - } - } if (ctx->HasInputs("CacheKV")) { // [2, batch_size, num_head, max_seq_len, head_size] @@ -129,13 +115,7 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { paddle::platform::errors::InvalidArgument( "The first dim of CacheKV must be 2, but got %d", c_dim[0])); // 2 - PADDLE_ENFORCE_EQ(c_dim[1], - x_dim[0], - paddle::platform::errors::InvalidArgument( - "The second dim of CacheKV must be equal with " - "batch size %d, but got %d", - x_dim[0], - c_dim[1])); // batch_size + PADDLE_ENFORCE_EQ(c_dim[2], trans_qkvw ? y_dim[1] : y_dim[2], paddle::platform::errors::InvalidArgument( @@ -143,12 +123,7 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { "head %d, but got %d", trans_qkvw ? y_dim[1] : y_dim[2], c_dim[2])); // num_head - PADDLE_ENFORCE_GT( - c_dim[3], - 0, - paddle::platform::errors::InvalidArgument( - "The forth dim of CacheKV must be greater than 0, but got %d", - c_dim[3])); // cache_seq_len + PADDLE_ENFORCE_EQ(c_dim[4], trans_qkvw ? y_dim[2] : y_dim[3], paddle::platform::errors::InvalidArgument( @@ -200,9 +175,21 @@ class FusedMultiTransformerINT8OpMaker AddInput("CacheKV", "(optional) The cached KV for generation inference.") .AsDispensable() .AsDuplicable(); + AddInput("PreCaches", + "(optional) The prefix caches for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("RotaryPosEmb", + "(optional) The RoPE embeddings for generation inference.") + .AsDispensable(); + AddInput("BeamCacheOffset", + "(optional) The offset of CacheKV when using BeamSearch.") + .AsDispensable(); AddInput("TimeStep", "(optional, int) The time step for generation inference.") .AsDispensable(); + AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.") + .AsDispensable(); AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") .AsDispensable(); AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); @@ -232,20 +219,24 @@ class FusedMultiTransformerINT8OpMaker "In order to keep consistent with the PTQ/QAT calculation logic," "QKVOutScale should be max_bound * max_bound / max_range." "Here max_range is per-channel weight scale." - "The shape of QKVOutScale is [num_layers, num_channels]") - .AsDispensable(); + "The shape of QKVOutScale is [num_channels]") + .AsDispensable() + .AsDuplicable(); AddInput("OutLinearOutScale", "OutLinearOutScale is used to dequantize out_linear output tensor." "The definition and shape is the same as QKVOutScale") - .AsDispensable(); + .AsDispensable() + .AsDuplicable(); AddInput("FFN1OutScale", "FFN1OutScale is used to dequantize ffn1 output tensor." "The definition and shape is the same as QKVOutScale") - .AsDispensable(); + .AsDispensable() + .AsDuplicable(); AddInput("FFN2OutScale", "FFN2OutScale is used to dequantize ffn2 output tensor." "The definition and shape is the same as QKVOutScale") - .AsDispensable(); + .AsDispensable() + .AsDuplicable(); AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") .AsDispensable() @@ -352,6 +343,18 @@ class FusedMultiTransformerINT8OpMaker "quant_min_bound", "(float, default -127.0) the min bound of float type to int type") .SetDefault(-127.0); + AddAttr("rotary_emb_dims", + "the Attr(dims) for RotaryPosEmb's Computation [default 0].") + .SetDefault(0) + .AddCustomChecker([](const int &rotary_emb_dims) { + PADDLE_ENFORCE_EQ( + rotary_emb_dims >= 0 && rotary_emb_dims <= 2, + true, + platform::errors::InvalidArgument( + "'rotary_emb_dims' in Op(Rotray) should be between" + "0 and 2, But received [%s].", + rotary_emb_dims)); + }); AddComment(R"DOC(fused multi transformer layers op)DOC"); } @@ -366,4 +369,4 @@ REGISTER_OPERATOR( ops::FusedMultiTransformerINT8Op, ops::FusedMultiTransformerINT8OpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu index 8e200275f8171..d87b3db45cb19 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -14,10 +14,33 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/attn_gemm_int8.h" #include "paddle/fluid/operators/fused/fused_multi_transformer_op.h" +#include "paddle/fluid/operators/fused/layernorm_quant_dequant.h" + +// DECLARE_int32(debug_layer_id); namespace paddle { namespace operators { +template +static void PrintMatrix(const T* mat_d, int num, std::string name) { + std::vector tmp(num); + cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); + + std::ofstream outfile; + outfile.open(name+".txt", std::ios::out); + std::stringstream ss; + + for (int i = 0; i < num; ++i) { + if(std::is_same::value) { + ss << static_cast(tmp[i]) << std::endl; + } else { + ss << std::setprecision(8) << tmp[i] << std::endl; + } + } + outfile << ss.str(); + outfile.close(); +} + template class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { public: @@ -25,9 +48,9 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { using U = LayerNormParamType; auto &dev_ctx = ctx.cuda_device_context(); - auto *time_step = ctx.Input("TimeStep"); + auto *time_step = ctx.Input("TimeStep"); // 0. input - auto *input_x = ctx.Input("X"); + auto *input_x = ctx.Input("X"); const auto input_x_dims = input_x->dims(); int bsz = input_x_dims[0]; int seq_len = input_x_dims[1]; @@ -48,36 +71,87 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { // dequant output scales, tensor, size = [num_layers, n], n is gemm output // size - auto *qkv_out_scale = ctx.Input("QKVOutScale"); - auto *out_linear_out_scale = ctx.Input("OutLinearOutScale"); - auto *ffn1_out_scale = ctx.Input("FFN1OutScale"); - auto *ffn2_out_scale = ctx.Input("FFN2OutScale"); + auto qkv_out_scales = ctx.MultiInput("QKVOutScale"); + auto out_linear_out_scales = + ctx.MultiInput("OutLinearOutScale"); + auto ffn1_out_scales = ctx.MultiInput("FFN1OutScale"); + auto ffn2_out_scales = ctx.MultiInput("FFN2OutScale"); + + bool remove_padding = false; + auto *sequence_lengths = ctx.Input("SeqLengths"); + if (sequence_lengths) { + remove_padding = true; + } + phi::DenseTensor d_token_tensor; + phi::DenseTensor padding_offset_tensor; + phi::DenseTensor x_remove_padding; + bool encoder_remove_padding = (remove_padding && !time_step); + int token_num = 0; + + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } - int qkv_out_scale_n = qkv_out_scale->dims()[1]; - int out_linear_out_scale_n = out_linear_out_scale->dims()[1]; - int ffn1_out_scale_n = ffn1_out_scale->dims()[1]; - int ffn2_out_scale_n = ffn2_out_scale->dims()[1]; + // remove padding in encoder + if (encoder_remove_padding) { + // just for encoder + d_token_tensor.Resize({{1}}); + auto *d_token_num = dev_ctx.Alloc( + &d_token_tensor, d_token_tensor.numel() * sizeof(int)); + // alloc the max size of padding_offset_tensor + padding_offset_tensor.Resize({{bsz_seq}}); + dev_ctx.Alloc(&padding_offset_tensor, + padding_offset_tensor.numel() * sizeof(int)); + InvokeGetPaddingOffset(dev_ctx, + &token_num, + d_token_num, + padding_offset_tensor.data(), + sequence_lengths->data(), + bsz, + seq_len); + padding_offset_tensor.Resize({{token_num}}); + // VLOG(0) << "padding_offset_tensor: " << padding_offset_tensor; + x_remove_padding.Resize({{token_num, dim_embed}}); + dev_ctx.Alloc(&x_remove_padding, x_remove_padding.numel() * sizeof(T)); + InvokeRemovePadding(dev_ctx, + x_remove_padding.data(), + input_x->data(), + padding_offset_tensor.data(), + token_num, + dim_embed); + } else { + token_num = bsz_seq; + } + + if (token_num == 0) { + return; + } + + auto *padding_offset_data = + encoder_remove_padding ? padding_offset_tensor.data() : nullptr; // 1. layer norm const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); const float epsilon = ctx.Attr("epsilon"); - auto ln_scales = ctx.MultiInput("LnScale"); - auto ln_biases = ctx.MultiInput("LnBias"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); auto ln_compute = - AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); - Tensor ln_mean, ln_var; - ln_mean.Resize({{bsz_seq}}); + AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + phi::DenseTensor ln_mean, ln_var; + ln_mean.Resize({{token_num}}); auto *ln_mean_data = dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_var.Resize({{bsz_seq}}); + ln_var.Resize({{token_num}}); auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); // 2. qkv // x: qkv's input [batch_size, seq_len, dim_embed] // y: qkv's weight: [3, num_head, dim_head, dim_embed] - auto qkv_weights = ctx.MultiInput("QKVW"); - auto qkv_biases = ctx.MultiInput("QKVBias"); + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); const bool trans_qkvw = ctx.Attr("trans_qkvw"); const auto qkv_w_dims = qkv_weights[0]->dims(); int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; @@ -89,21 +163,31 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; // (transA, transB, compute_bias) = (false, trans_qkvw, false) AttnMatmulINT8 qkv_compute( - dev_ctx, bsz_seq, output_size, input_size, compute_bias); - Tensor qkv_out; - qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); + dev_ctx, token_num, output_size, input_size, false /*compute_bias*/); + phi::DenseTensor qkv_out; + qkv_out.Resize({{token_num, 3, num_head, dim_head}}); auto *qkv_out_data = dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + // 2.1 rotary + auto *rotary_tensor = ctx.Input("RotaryPosEmb"); + const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); + // 3. fmha AttnDropoutParam attn_param( true, "upscale_in_train", 0.0, true, true, 0, nullptr); auto fmha_compute = FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); - auto *src_mask = ctx.Input("SrcMask"); - auto cache_kvs = ctx.MultiInput("CacheKV"); - auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); - // auto *time_step = ctx.Input("TimeStep"); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + // auto *time_step = ctx.Input("TimeStep"); + + auto pre_caches = ctx.MultiInput("PreCaches"); + int cache_offset = 0; + if (pre_caches.size() > 0) { + cache_offset = pre_caches[0]->dims()[3]; + } auto out_seq_len = seq_len; if (time_step) { @@ -125,147 +209,203 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { "In decode stage, the seq_len of input must be 1, but now is %d", seq_len)); out_seq_len += time_step_value; + } else { + out_seq_len += cache_offset; + } + + phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; + q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *q_transpose_out_data = + dev_ctx.Alloc(&q_transpose_out, q_transpose_out.numel() * sizeof(T)); + + kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}}); + auto *kv_transpose_out_data = dev_ctx.Alloc( + &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); + + if (encoder_remove_padding) { + InitValue(dev_ctx, q_transpose_out_data, q_transpose_out.numel(), static_cast(0.)); + InitValue(dev_ctx, kv_transpose_out_data, kv_transpose_out.numel(), static_cast(0.)); } - Tensor transpose_out_2, qk_out; - transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); - auto *transpose_out_2_data = - dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); - Tensor softmax_out; - Tensor attn_dropout_mask_out, attn_dropout_out; - Tensor qktv_out, fmha_out; + phi::DenseTensor src_mask_out; + if (cache_offset > 0) { + src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *src_mask_out_data = + dev_ctx.Alloc(&src_mask_out, src_mask_out.numel() * sizeof(T)); + } + + // [2, bs, num_head, cache_seq_len + seq_len, head_dim] + phi::DenseTensor pre_cache_kv_out; + if (cache_offset > 0) { + pre_cache_kv_out.Resize( + {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); + auto *pre_cache_kv_out_data = dev_ctx.Alloc( + &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); + } + + phi::DenseTensor softmax_out; + phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; + phi::DenseTensor qktv_out, fmha_out; softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *softmax_out_data = dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); - attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_mask_out_data = dev_ctx.Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_data_data = dev_ctx.Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); + T *attn_dropout_mask_out_data = nullptr; + T *attn_dropout_data_data = nullptr; qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); auto *qktv_out_data = dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); - fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); + fmha_out.Resize({{token_num, num_head, dim_head}}); auto *fmha_out_data = dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); // 4. out_linear - auto out_linear_weights = ctx.MultiInput("OutLinearW"); - auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); int ring_id = ctx.Attr("ring_id"); // (transA, transB, compute_bias) = (false, false, false) AttnMatmulINT8 out_linear_compute( - dev_ctx, bsz_seq, dim_embed, hidden_size, false); + dev_ctx, token_num, dim_embed, hidden_size, false); // 5. ln(residual + bias) DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); + FusedDropoutLayerNormHelper + fused_dropout_layernorm_helper_just_dequant( + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper_for_post_layernorm( - dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); - auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); - auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); - Tensor bias_dropout_residual_out, dropout_mask_out; + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); + + using LayerNormComputeType = float; + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; T *bias_dropout_residual_out_data = nullptr; if (pre_layer_norm) { - bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}}); + bias_dropout_residual_out.Resize({{token_num, dim_embed}}); bias_dropout_residual_out_data = dev_ctx.Alloc(&bias_dropout_residual_out, bias_dropout_residual_out.numel() * sizeof(T)); } - dropout_mask_out.Resize({{bsz, seq_len, dim_embed}}); - auto *dropout_mask_out_data = dev_ctx.Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); + uint8_t *dropout_mask_out_data = nullptr; // 6. ffn matmul1 - auto ffn1_weights = ctx.MultiInput("FFN1Weight"); - auto ffn1_biases = ctx.MultiInput("FFN1Bias"); + auto ffn1_weights = ctx.MultiInput("FFN1Weight"); + auto ffn1_biases = ctx.MultiInput("FFN1Bias"); auto ffn1_weight_dim = ffn1_weights[0]->dims(); int dim_ffn = ffn1_weight_dim[0]; AttnMatmulINT8 ffn1_linear_compute( - dev_ctx, bsz_seq, dim_ffn, dim_embed, false); - Tensor ffn1_out; - ffn1_out.Resize({{bsz_seq, dim_ffn}}); + dev_ctx, token_num, dim_ffn, dim_embed, false); + phi::DenseTensor ffn1_out; + ffn1_out.Resize({{token_num, dim_ffn}}); auto *ffn1_out_data = dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); // 7. ffn act + bias DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutHelper fused_act_dropout_helper( - dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); + dev_ctx, token_num, dim_ffn, ffn1_dropout_param); FusedDropoutHelper fused_act_dropout_helper_for_post_layernorm( - dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); - Tensor ffn1_dropout_out, ffn1_dropout_mask; - ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}}); + dev_ctx, token_num, dim_ffn, ffn1_dropout_param); + phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; + ffn1_dropout_out.Resize({{token_num, dim_ffn}}); auto *ffn1_dropout_out_data = dev_ctx.Alloc( &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); - ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}}); - auto *ffn1_dropout_mask_data = dev_ctx.Alloc( - &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); + uint8_t * ffn1_dropout_mask_data = nullptr; // 8. ffn2 matmul - auto ffn2_weights = ctx.MultiInput("FFN2Weight"); - auto ffn2_biases = ctx.MultiInput("FFN2Bias"); + auto ffn2_weights = ctx.MultiInput("FFN2Weight"); + auto ffn2_biases = ctx.MultiInput("FFN2Bias"); AttnMatmulINT8 ffn2_linear_compute( - dev_ctx, bsz_seq, dim_embed, dim_ffn, false); + dev_ctx, token_num, dim_embed, dim_ffn, false); // 9. ffn2 residual bias DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); FusedDropoutLayerNormHelper ffn2_fused_dropout_dequant_helper( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); FusedDropoutLayerNormHelper ffn2_fused_dropout_helper_for_post_layernorm( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); // []. init workspace for cublasLt transform - Tensor input_workspace, output_workspace; + phi::DenseTensor input_workspace, output_workspace, cublaslt_workspace; // for input and output transform data is CUBLASLT_ORDER_COL32 format, - int m_max = bsz_seq, k_max = std::max(dim_embed, dim_ffn), + int m_max = token_num, k_max = std::max(dim_embed, dim_ffn), n_max = std::max({output_size, dim_embed, dim_ffn}); - input_workspace.Resize( - {{32 * ((m_max + 32 - 1) / 32), (k_max + 31) / 32 * 32}}); + input_workspace.Resize({{(m_max * k_max + 31) / 32 * 32}}); dev_ctx.Alloc(&input_workspace, input_workspace.numel() * sizeof(int8_t)); - output_workspace.Resize({{n_max * 4, (m_max + 31) / 32 * 32 * 4}}); + + output_workspace.Resize({{(n_max * m_max + 31) / 32 * 32}}); dev_ctx.Alloc(&output_workspace, output_workspace.numel() * sizeof(int32_t)); + cublaslt_workspace.Resize({{3000000}}); + dev_ctx.Alloc(&cublaslt_workspace, + cublaslt_workspace.numel() * sizeof(int8_t)); + // calc - auto *out = ctx.Output("Out"); + auto *out = ctx.Output("Out"); auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - Tensor *from_tensor = out; - Tensor tmp_out; - tmp_out.Resize({{bsz, seq_len, dim_embed}}); + if (encoder_remove_padding) { + InitValue(dev_ctx, from_data, out->numel(), static_cast(0.)); + } + + // phi::DenseTensor *from_tensor = out; + // phi::DenseTensor tmp_out; + // tmp_out.Resize({{token_num, dim_embed}}); + + phi::DenseTensor tmp_out, tmp_out_rm_padding; + tmp_out.Resize({{token_num, dim_embed}}); + if (encoder_remove_padding) { + tmp_out_rm_padding.Resize({{token_num, dim_embed}}); + auto *tmp_out_rm_padding_data = dev_ctx.Alloc( + &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); + } + auto *tmp_out_data = dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); - auto *x_data = input_x->data(); - Tensor *buf0 = nullptr; - Tensor *buf1 = nullptr; + const T *x_data; + if (encoder_remove_padding) { + x_data = x_remove_padding.data(); + } else { + x_data = input_x->data(); + } + + phi::DenseTensor *buf0 = nullptr; + phi::DenseTensor *buf1 = nullptr; // step0: x --> buf1 // step1: buf1 --> buf0 // step2: buf0 --> buf1 int layers = qkv_weights.size(); - if (pre_layer_norm) { - buf1 = out; - } else { + if (encoder_remove_padding) { + // In the case of variable lengths, the padding needs to be rebuilt + // eventually. So buf0 and buf1 do not need to be changed according to the + // pre_layer_norm and the number of layers. buf0 = &tmp_out; - buf1 = out; + buf1 = &tmp_out_rm_padding; + } else { + if (pre_layer_norm) { + buf1 = out; + } else { + buf0 = &tmp_out; + buf1 = out; + } } for (int i = 0; i < layers; ++i) { @@ -274,6 +414,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { auto *ln_scale_data = ln_scales[i]->data(); auto *ln_bias_data = ln_biases[i]->data(); // TODO(wangxi): can remove mean var in inference + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "fmt in " << *input_x; ln_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, @@ -292,20 +434,23 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { #endif // step2. qkv - const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + const phi::DenseTensor *qkv_bias = + qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; // NOTE: in decoder stage, bias is fused in fmha - const Tensor *bias = time_step ? nullptr : qkv_bias; + const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; if (!pre_layer_norm && i == 0) { + const phi::DenseTensor *tmp_input_x = + (encoder_remove_padding) ? &x_remove_padding : input_x; qkv_compute.ComputeForward(qkv_weights[i], - input_x, + tmp_input_x, &input_workspace, bias, &qkv_out, &output_workspace, &qkv_out, qkv_in_scale[i], - qkv_out_scale, - i * qkv_out_scale_n, + qkv_out_scales[i], + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); @@ -318,12 +463,16 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, &qkv_out, qkv_in_scale[i], - qkv_out_scale, - i * qkv_out_scale_n, + qkv_out_scales[i], + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); } else { + // if (i == FLAGS_debug_layer_id) { + // VLOG(2) << "qkv in " << input_workspace; + // VLOG(2) << "qkv weight " << *qkv_weights[i]; + // } qkv_compute.ComputeForwardINT8ToT(qkv_weights[i], qkv_in_scale[i], &input_workspace, @@ -331,53 +480,116 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &qkv_out, &output_workspace, &qkv_out, - qkv_out_scale, - i * qkv_out_scale_n); + qkv_out_scales[i], + &cublaslt_workspace); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step2"; #endif + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "qkv out " << qkv_out; // step3. fmha - const Tensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; - Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + const phi::DenseTensor *cache_kv = + cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; if (time_step) { // generation decoder stage // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; fmha(dev_ctx, qkv_out, *qkv_bias, *src_mask, + sequence_lengths, + rotary_tensor, + beam_cache_offset, cache_kv_out, &fmha_out, bsz, + beam_size, + // 1, max_seq_len, num_head, dim_head, - time_step->data()[0], + src_mask->dims()[3] - 1, + rotary_emb_dims, 1. / sqrt(dim_head)); } else if (cache_kv_out) { // generation context stage // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - nullptr, - src_mask, - &transpose_out_2, - nullptr, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); - // [3, bsz, num_head, seq_len, head_dim] - T *qkv_data = transpose_out_2_data; - int64_t q_size = bsz * seq_len * num_head * dim_head; - int64_t k_size = q_size; - const T *q_ptr = qkv_data; - const T *k_ptr = q_ptr + q_size; - const T *v_ptr = k_ptr + k_size; + const phi::DenseTensor *pre_cache_kv_tensor = + pre_caches.size() > 0 ? pre_caches[i] : nullptr; + phi::DenseTensor *pre_cache_kv_out_tmp = + cache_offset > 0 ? &pre_cache_kv_out : nullptr; + phi::DenseTensor *src_mask_tmp = + cache_offset > 0 ? &src_mask_out : nullptr; + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, + src_mask, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + pre_cache_kv_out_tmp, + &qk_out, + src_mask_tmp, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); + + const T *k_ptr = nullptr; + const T *v_ptr = nullptr; + + if (cache_offset > 0) { + // [2, bsz, num_head, cache_offset + seq_len, head_dim] + const T *kv_data = pre_cache_kv_out.data(); + k_ptr = kv_data; + int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; + v_ptr = k_ptr + k_size; + } else { + // [3, bsz, num_head, seq_len, head_dim] + int64_t k_size = bsz * seq_len * num_head * dim_head; + const T *q_ptr = q_transpose_out_data; + k_ptr = kv_transpose_out_data; + v_ptr = k_ptr + k_size; + } // [2, bsz, num_head, max_seq_len, head_dim] int max_seq_len = cache_kv_out->dims()[3]; @@ -387,30 +599,68 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { T *cache_k_ptr = cache_kv_data; T *cache_v_ptr = cache_kv_data + cache_k_size; + const int seq_len_tmp = seq_len + cache_offset; write_cache_kv(dev_ctx, cache_k_ptr, cache_v_ptr, k_ptr, v_ptr, + sequence_lengths_data, bsz, num_head, - seq_len, + seq_len_tmp, max_seq_len, dim_head); } else { // not generation // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - cache_kv, - src_mask, - &transpose_out_2, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(cache_kv, + src_mask, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + cache_kv_out, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step3"; @@ -424,6 +674,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { nullptr, &output_workspace, nullptr, + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); @@ -431,6 +682,12 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { ring_id, bsz * seq_len * num_head * dim_head, dev_ctx); + // if (i == FLAGS_debug_layer_id) { + // VLOG(2) << "fmha_out " << fmha_out; + // VLOG(2) << "out_linear weight " << *out_linear_weights[i]; + // VLOG(2) << out_linear_in_scale[i]; + // VLOG(2) << "out_linear_out " << output_workspace; + // } } else { out_linear_compute.ComputeForward(out_linear_weights[i], &fmha_out, @@ -440,8 +697,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, nullptr, out_linear_in_scale[i], - out_linear_out_scale, - i * out_linear_out_scale_n, + out_linear_out_scales[i], + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); @@ -453,31 +710,67 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { // step5. ln(residual + dropout(input + bias)) if (pre_layer_norm) { + VLOG(1) << "ffn1 in scale " << ffn1_in_scale[i]; auto *ln_scale_data = ffn_ln_scales[i]->data(); auto *ln_bias_data = ffn_ln_biases[i]->data(); auto *out_linear_bias_data = out_linear_biases[i]->data(); + + // inplace // non-inplace: buf1 -> input_workspace - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - output_workspace.data(), - x_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - bias_dropout_residual_out_data, - dropout_mask_out_data, - input_workspace.data(), - ln_mean_data, - ln_var_data, - out_linear_in_scale[i], - out_linear_out_scale->data(), - i * out_linear_out_scale_n, - ffn1_in_scale[i], - quant_round_type, - quant_max_bound, - quant_min_bound); + // fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // x_data, + // out_linear_bias_data, + // ln_scale_data, + // ln_bias_data, + // bias_dropout_residual_out_data, + // dropout_mask_out_data, + // input_workspace.data(), + // ln_mean_data, + // ln_var_data, + // out_linear_in_scale[i], + // out_linear_out_scales[i]->data(), + // ffn1_in_scale[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + + // phi::DenseTensor ffn_ln_out; + // ffn_ln_out.Resize(input_x->dims()); + // dev_ctx.Alloc(&ffn_ln_out); + + // fused_dropout_layernorm_helper_just_dequant.LayernormResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // x_data, + // out_linear_bias_data, + // ln_scale_data, + // ln_bias_data, + // bias_dropout_residual_out_data, + // dropout_mask_out_data, + // ffn_ln_out.data(), + // ln_mean_data, + // ln_var_data, + // out_linear_in_scale[i], + // out_linear_out_scales[i]->data(), + // ffn1_in_scale[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // LaunchQuantActKernel(ffn_ln_out.data(), bsz_seq, dim_embed, input_workspace.data(), ffn1_in_scale[i], quant_max_bound, quant_min_bound, dev_ctx.stream()); + + // VLOG(1) << "RIGHT out " << input_workspace; + // DequantSkipLoad load(output_workspace.data(), out_linear_bias_data, x_data, out_linear_out_scales[i]->data(), 0.0f, dim_embed); + DequantSkipLoadAndStoreResidual load(output_workspace.data(), out_linear_bias_data, x_data, + out_linear_out_scales[i]->data(), bias_dropout_residual_out_data, 0.0f, dim_embed); + AffineQuantStore store(input_workspace.data(), dim_embed, + ln_scale_data, ln_bias_data, ffn1_in_scale[i], quant_round_type, quant_max_bound, quant_min_bound); + DispatchLayerNorm(dev_ctx.stream(), load, store, token_num, dim_embed, epsilon, ln_mean_data, ln_var_data); + VLOG(1) << "WRONG out " << input_workspace; + } else { auto *ln_scale_data = ln_scales[i]->data(); auto *ln_bias_data = ln_biases[i]->data(); @@ -498,7 +791,9 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step5"; -#endif +#endif + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "ffn1_in " << input_workspace; // step6. ffn matmul1 @@ -507,7 +802,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &input_workspace, nullptr, &output_workspace, - nullptr); + nullptr, + &cublaslt_workspace); } else { ffn1_linear_compute.ComputeForward(ffn1_weights[i], buf1, @@ -517,8 +813,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, nullptr, ffn1_in_scale[i], - ffn1_out_scale, - i * ffn1_out_scale_n, + ffn1_out_scales[i], + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); @@ -526,6 +822,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step6"; #endif + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "ffn1 out " << output_workspace; // step7. act bias // TODO(wangxi): remove dropout mask in inference @@ -538,8 +836,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { input_workspace.data(), ffn1_dropout_mask_data, ffn1_in_scale[i], - ffn1_out_scale->data(), - i * ffn1_out_scale_n, + ffn1_out_scales[i]->data(), ffn2_in_scale[i], quant_round_type, quant_max_bound, @@ -556,6 +853,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step7"; #endif + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "ffn2 in " << input_workspace; // step8. ffn matmul2 if (pre_layer_norm) { @@ -563,7 +862,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &input_workspace, nullptr, &output_workspace, - nullptr); + nullptr, + &cublaslt_workspace); } else { ffn2_linear_compute.ComputeForward(ffn2_weights[i], &ffn1_dropout_out, @@ -573,8 +873,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, nullptr, ffn2_in_scale[i], - ffn2_out_scale, - i * ffn2_out_scale_n, + ffn2_out_scales[i], + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); @@ -582,6 +882,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step8.0"; #endif + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "ffn2 out " << output_workspace; if (pre_layer_norm) { AllReduce(output_workspace, @@ -602,25 +904,57 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { auto *ln_scale_data = ln_scales[i + 1]->data(); auto *ln_bias_data = ln_biases[i + 1]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - output_workspace.data(), - bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - ln_scale_data, - ln_bias_data, - buf1->data(), - dropout_mask_out_data, - input_workspace.data(), - ln_mean_data, - ln_var_data, - ffn2_in_scale[i], - ffn2_out_scale->data(), - i * ffn2_out_scale_n, - qkv_in_scale[i + 1], - quant_round_type, - quant_max_bound, - quant_min_bound); + // ffn2_fused_dropout_helper.LayernormResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // bias_dropout_residual_out_data, + // ffn2_biases[i]->data(), + // ln_scale_data, + // ln_bias_data, + // buf1->data(), + // dropout_mask_out_data, + // input_workspace.data(), + // ln_mean_data, + // ln_var_data, + // ffn2_in_scale[i], + // ffn2_out_scales[i]->data(), + // qkv_in_scale[i + 1], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + + phi::DenseTensor ln_out; + ln_out.Resize(input_x->dims()); + dev_ctx.Alloc(&ln_out); + + // fused_dropout_layernorm_helper_just_dequant.LayernormResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // bias_dropout_residual_out_data, + // ffn2_biases[i]->data(), + // ln_scale_data, + // ln_bias_data, + // buf1->data(), + // dropout_mask_out_data, + // ln_out.data(), + // ln_mean_data, + // ln_var_data, + // ffn2_in_scale[i], + // ffn2_out_scales[i]->data(), + // qkv_in_scale[i + 1], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // LaunchQuantActKernel(ln_out.data(), bsz_seq, dim_embed, input_workspace.data(), qkv_in_scale[i + 1], quant_max_bound, quant_min_bound, dev_ctx.stream()); + // VLOG(1) << "RIGHT out " << input_workspace; + + // DequantSkipLoad load(output_workspace.data(), ffn2_biases[i]->data(), bias_dropout_residual_out_data, ffn2_out_scales[i]->data(), 0.0f, dim_embed); + DequantSkipLoadAndStoreResidual load(output_workspace.data(), ffn2_biases[i]->data(), bias_dropout_residual_out_data, + ffn2_out_scales[i]->data(), buf1->data(), 0.0f, dim_embed); + AffineQuantStore store(input_workspace.data(), dim_embed, + ln_scale_data, ln_bias_data, qkv_in_scale[i + 1], quant_round_type, quant_max_bound, quant_min_bound); + DispatchLayerNorm(dev_ctx.stream(), load, store, token_num, dim_embed, epsilon, ln_mean_data, ln_var_data); + VLOG(1) << "WRONG out " << input_workspace; } else { ffn2_fused_dropout_dequant_helper.ResidualDropoutBias( dev_ctx, @@ -630,8 +964,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { buf1->data(), dropout_mask_out_data, ffn2_in_scale[i], - ffn2_out_scale->data(), - i * ffn2_out_scale_n, + ffn2_out_scales[i]->data(), 1.0); } } else { @@ -656,6 +989,24 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { if (pre_layer_norm) { x_data = buf1->data(); } + VLOG(2) << "out layer " << i << " " << *buf1; + } + if (encoder_remove_padding) { + if (pre_layer_norm) { + InvokeRebuildPadding(dev_ctx, + from_data, + buf0->data(), + padding_offset_data, + token_num, + dim_embed); + } else { + InvokeRebuildPadding(dev_ctx, + from_data, + buf1->data(), + padding_offset_data, + token_num, + dim_embed); + } } } }; @@ -667,4 +1018,4 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(fused_multi_transformer_int8, ops::FusedMultiTransformerINT8OpKernel, - ops::FusedMultiTransformerINT8OpKernel); + ops::FusedMultiTransformerINT8OpKernel); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cc new file mode 100644 index 0000000000000..bc84a4613c56b --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cc @@ -0,0 +1,392 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = phi::DenseTensor; + +class FusedMultiTransformerMoeINT8Op : public framework::OperatorWithKernel { + private: + static constexpr const char *OpName = "FusedMultiTransformerMoeINT8Op"; + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { +#define CHECK_INPUT(name) \ + OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName) +#define CHECK_INPUTS(name) \ + OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName) +#define CHECK_OUTPUT(name) \ + OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName) +#define CHECK_OUTPUTS(name) \ + OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName) + + CHECK_INPUT(X); + + // attention + CHECK_INPUTS(QKVW); + CHECK_INPUTS(OutLinearW); + + if (ctx->HasInput("TimeStep")) { + CHECK_INPUTS(CacheKV); + } + + if (ctx->HasInputs("CacheKV")) { + CHECK_OUTPUTS(CacheKVOut); + } + + // moe + CHECK_INPUTS(GateWeight); + CHECK_INPUTS(GateBias); + CHECK_INPUTS(ExpertWeight1); + CHECK_INPUTS(ExpertWeight2); + + // scale + CHECK_INPUTS(QKVOutScale); + CHECK_INPUTS(OutLinearOutScale); + CHECK_INPUTS(ExpertWeight1OutScale); + CHECK_INPUTS(ExpertWeight2OutScale); + + CHECK_OUTPUT(Out); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputsDim("QKVW")[0]; + bool trans_qkvw = ctx->Attrs().Get("trans_qkvw"); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 3, + platform::errors::InvalidArgument("The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + PADDLE_ENFORCE_EQ(y_dim.size(), + 4, + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ( + x_dim[2], + trans_qkvw ? y_dim[3] : y_dim[0], + platform::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is " + "true) or y_dim[0](trans_qkvw is false)" + "must be equal. But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, + y_dim)); + + + if (ctx->HasInputs("CacheKV")) { + // [2, batch_size, num_head, max_seq_len, head_size] + const auto &c_dims = ctx->GetInputsDim("CacheKV"); + const auto &c_dim = c_dims[0]; + + PADDLE_ENFORCE_EQ( + c_dim.size(), + 5, + paddle::platform::errors::InvalidArgument( + "The CacheKV must be 5 dims, but got %d", c_dim.size())); + PADDLE_ENFORCE_EQ(c_dim[0], + 2, + paddle::platform::errors::InvalidArgument( + "The first dim of CacheKV must be 2, but got %d", + c_dim[0])); // 2 + + PADDLE_ENFORCE_EQ(c_dim[2], + trans_qkvw ? y_dim[1] : y_dim[2], + paddle::platform::errors::InvalidArgument( + "The third dim of CacheKV must be equal with num " + "head %d, but got %d", + trans_qkvw ? y_dim[1] : y_dim[2], + c_dim[2])); // num_head + + PADDLE_ENFORCE_EQ(c_dim[4], + trans_qkvw ? y_dim[2] : y_dim[3], + paddle::platform::errors::InvalidArgument( + "The fifth dim of CacheKV must be equal with head " + "size %d, but got %d", + trans_qkvw ? y_dim[2] : y_dim[3], + c_dim[4])); // head_size + } + + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, + const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "TimeStep") { + VLOG(10) << "var_name:" << var_name << " need not to transform"; + return expected_kernel_type; + } + return framework::OpKernelType( + expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + } +}; + +class FusedMultiTransformerMoeINT8OpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("LnScale", + "Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("LnBias", + "Bias is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("QKVW", "The qkv weight tensor.").AsDuplicable(); + AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable(); + + AddInput("CacheKV", "(optional) The cached KV for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("PreCaches", + "(optional) The prefix caches for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("RotaryPosEmb", + "(optional) The RoPE embeddings for generation inference.") + .AsDispensable(); + AddInput("BeamCacheOffset", + "(optional) The offset of CacheKV when using BeamSearch.") + .AsDispensable(); + AddInput("TimeStep", + "(optional, int) The time step for generation inference.") + .AsDispensable(); + AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.") + .AsDispensable(); + AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") + .AsDispensable(); + AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); + AddInput("OutLinearBias", "The out_linear bias tensor.") + .AsDispensable() + .AsDuplicable(); + + AddInput("GateWeight", "The gate_weights in moe") + .AsDuplicable(); + AddInput("GateBias", "The gate_biases in moe") + .AsDuplicable(); + AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") + .AsDuplicable(); + AddInput("ExpertWeight1", "The expert_weights1 in moe") + .AsDuplicable(); + AddInput("ExpertBias1", "The expert_biases1 in moe") + .AsDuplicable(); + AddInput("ExpertWeight2", "The expert_weights2 in moe") + .AsDuplicable(); + AddInput("ExpertBias2", "The expert_biases2 in moe") + .AsDuplicable(); + + // out scale + AddInput("QKVOutScale", + "QKVOutScale is used to dequantize qkv output tensor." + "In order to keep consistent with the PTQ/QAT calculation logic," + "QKVOutScale should be max_bound * max_bound / max_range." + "Here max_range is per-channel weight scale." + "The shape of QKVOutScale is [num_layers]") + .AsDispensable() + .AsDuplicable(); + AddInput("OutLinearOutScale", + "OutLinearOutScale is used to dequantize out_linear output tensor." + "The definition and shape is the same as QKVOutScale") + .AsDispensable() + .AsDuplicable(); + AddInput("ExpertWeight1OutScale", + "ExpertWeight1OutScale is used to dequantize ffn1 output tensor." + "The definition and shape is num_layers * num_expert") + .AsDispensable() + .AsDuplicable(); + AddInput("ExpertWeight2OutScale", + "ExpertWeight2OutScale is used to dequantize ffn2 output tensor." + "The definition and shape is num_layers * num_expert") + .AsDispensable() + .AsDuplicable(); + + AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") + .AsDispensable() + .AsDuplicable(); + AddOutput("Out", "Result after multi ."); + + AddAttr("pre_layer_norm", + "if true, the attention op uses pre_layer_norm architecure, " + "else, uses post_layer_norm architecuture. " + "[default true].") + .SetDefault(true); + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, + true, + platform::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + }); + + AddAttr("dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, + true, + platform::errors::InvalidArgument( + "'dropout_rate' must be between 0.0 and 1.0.")); + }); + + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "The meaning is the same as 'attn_dropout_implementation'.") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", + true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + AddAttr("act_method", "act_method").SetDefault("gelu"); + AddAttr( + "trans_qkvw", + "Whether the weights of qkv should be transposed. If true," + "the shape eights of qkv should be [3, num_head, dim_head, dim_embed]." + "Otherwise the shape of weights of qkv should be" + "[dim_embed, 3, num_head, dim_head]") + .SetDefault(true); + AddAttr( + "ring_id", + "ring id for tensor model parallel. distributed training and inference") + .SetDefault(-1); + + // for moe layer + AddAttr( + "topk", + "gate's topk im moe") + .SetDefault(2); + AddAttr( + "mp_size", + "mp size") + .SetDefault(1); + AddAttr( + "mp_rank", + "mp rank") + .SetDefault(0); + AddAttr( + "num_expert", + "experts num im moe") + .SetDefault(1); + AddAttr( + "world_size", + "world size") + .SetDefault(1); + AddAttr( + "moe_ring_id", + "experts communicate group's ring id") + .SetDefault(1); + AddAttr( + "approximate", + "approximate in expert compute gelu") + .SetDefault(true); + + // int8 add + // AddAttr("num_head", "num_head").SetDefault(0); + // AddAttr("dim_head", "dim_head").SetDefault(0); + // AddAttr("dim_ffn", "dim_ffn").SetDefault(0); + + AddAttr>( + "qkv_in_scale", + "qkv_in_scale is used to quantize qkv input tensor." + "in_scale is generated by PTQ or QAT, which represents valid max range " + "of this tensor." + "the size of qkv_in_scale should be num_layers, which is equal to " + "QKVW.dims()[0]") + .SetDefault({}); + AddAttr>( + "out_linear_in_scale", + "out_linear_in_scale is used to quantize out_linear input tensor." + "the size of out_linear_in_scale is the same as qkv_in_scale") + .SetDefault({}); + AddAttr>( + "expert_weight1_in_scale", + "expert_weight1_in_scale is used to quantize ffn1 input tensor." + "the size of expert_weight1_in_scale should be num_layers * num_expert") + .SetDefault({}); + AddAttr>( + "expert_weight2_in_scale", + "expert_weight2_in_scale is used to quantize ffn2 input tensor." + "the size of expert_weight2_in_scale should be num_layers * num_expert") + .SetDefault({}); + + AddAttr( + "quant_round_type", + "(int, default 1) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(-2.5)=-3") + .SetDefault(1); + AddAttr( + "quant_max_bound", + "(float, default 127.0) the max bound of float type to int type") + .SetDefault(127.0); + AddAttr( + "quant_min_bound", + "(float, default -127.0) the min bound of float type to int type") + .SetDefault(-127.0); + AddComment(R"DOC(fused multi transformer layers op)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + fused_multi_transformer_moe_int8, + ops::FusedMultiTransformerMoeINT8Op, + ops::FusedMultiTransformerMoeINT8OpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu new file mode 100644 index 0000000000000..4869b14ca10de --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu @@ -0,0 +1,782 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h" +#include "paddle/fluid/operators/fused/layernorm_quant_dequant.h" + +namespace paddle { +namespace operators { + +using Tensor = phi::DenseTensor; + +template +static void PrintMatrix(const T* mat_d, int num, std::string name) { + std::vector tmp(num); + cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); + + std::ofstream outfile; + outfile.open(name+".txt", std::ios::out); + std::stringstream ss; + + for (int i = 0; i < num; ++i) { + if(std::is_same::value) { + ss << static_cast(tmp[i]) << std::endl; + } else { + ss << std::setprecision(8) << tmp[i] << std::endl; + } + } + outfile << ss.str(); + outfile.close(); +} + +template +class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto &dev_ctx = ctx.cuda_device_context(); + + auto *time_step = ctx.Input("TimeStep"); + // 0. input + auto *input_x = ctx.Input("X"); + const auto input_x_dims = input_x->dims(); + int bsz = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int bsz_seq = bsz * seq_len; + if (bsz_seq == 0) { + return; + } + + // quant input scales, vector, size = num_layers + auto qkv_in_scale = ctx.Attr>("qkv_in_scale"); + auto out_linear_in_scale = + ctx.Attr>("out_linear_in_scale"); + // moe expert scales, vector, size = num_expert * num_layers + auto expert_weight1_in_scale = ctx.Attr>("expert_weight1_in_scale"); + auto expert_weight2_in_scale = ctx.Attr>("expert_weight2_in_scale"); + + // quant round type and bound + auto quant_round_type = ctx.Attr("quant_round_type"); + auto quant_max_bound = ctx.Attr("quant_max_bound"); + auto quant_min_bound = ctx.Attr("quant_min_bound"); + + // dequant output scales, tensor, size = [num_layers, n], n is gemm output + // size + auto qkv_out_scales = ctx.MultiInput("QKVOutScale"); + auto out_linear_out_scales = + ctx.MultiInput("OutLinearOutScale"); + // dequant output scales, tensor, size = [num_layers * num_expert, n], n is gemm output + // size + auto expert_weight1_out_scales = ctx.MultiInput("ExpertWeight1OutScale"); + auto expert_weight2_out_scales = ctx.MultiInput("ExpertWeight2OutScale"); + + auto *sequence_lengths = ctx.Input("SeqLengths"); + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } + + // 1. layer norm + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + if (!pre_layer_norm) { + VLOG(0) << "not support post layer norm!"; + return; + } + const float epsilon = ctx.Attr("epsilon"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); + + // in type is T, out type is int8_t + auto ln_compute = + AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); + Tensor ln_mean, ln_var; + ln_mean.Resize({{bsz_seq}}); + auto *ln_mean_data = + dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_var.Resize({{bsz_seq}}); + auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); + + // 2. qkv + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); + const bool trans_qkvw = ctx.Attr("trans_qkvw"); + const auto qkv_w_dims = qkv_weights[0]->dims(); + int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; + // (transA, transB, compute_bias) = (false, trans_qkvw, false) + AttnMatmulINT8 qkv_compute( + dev_ctx, bsz_seq, output_size, input_size, compute_bias); + Tensor qkv_out; + qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); + auto *qkv_out_data = + dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + + // 3. fmha + AttnDropoutParam attn_param( + true, "upscale_in_train", 0.0, true, true, 0, nullptr); + auto fmha_compute = + FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + + int time_step_cpu = 0; + auto out_seq_len = seq_len; + if (time_step) { + time_step_cpu = src_mask->dims()[3] - 1; + out_seq_len += time_step_cpu; + } + + Tensor transpose_out_2, qk_out; + transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); + auto *transpose_out_2_data = + dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); + qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + + Tensor softmax_out; + Tensor attn_dropout_mask_out, attn_dropout_out; + Tensor qktv_out, fmha_out; + softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *softmax_out_data = + dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + + qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *qktv_out_data = + dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); + fmha_out.Resize({{bsz_seq, num_head, dim_head}}); + auto *fmha_out_data = + dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + + // 4. out_linear 注意!!这个weight的维度需要transpose!! + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + int ring_id = ctx.Attr("ring_id"); + // (transA, transB, compute_bias) = (false, false, false) + AttnMatmulINT8 out_linear_compute( + dev_ctx, bsz_seq, dim_embed, hidden_size, false); + + // 5. ln(residual + bias) + DropoutParam dropout_param(false, 0, true, true, 0.0, nullptr, 0); + + using LayerNormComputeType = float; + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + Tensor bias_dropout_residual_out, dropout_mask_out; + T *bias_dropout_residual_out_data = nullptr; + bias_dropout_residual_out.Resize({{bsz_seq, dim_embed}}); + bias_dropout_residual_out_data = + dev_ctx.Alloc(&bias_dropout_residual_out, + bias_dropout_residual_out.numel() * sizeof(T)); + uint8_t *dropout_mask_out_data = nullptr; + + // 6. moe layer: gate / expert_w & b / some attrs + auto gate_weights = ctx.MultiInput("GateWeight"); + auto gate_biases = ctx.MultiInput("GateBias"); + // weight的维度需要transpose!!!! + auto expert_weights1 = ctx.MultiInput("ExpertWeight1"); + auto expert_biases1 = ctx.MultiInput("ExpertBias1"); + auto expert_weights2 = ctx.MultiInput("ExpertWeight2"); + auto expert_biases2 = ctx.MultiInput("ExpertBias2"); + int dim_feedforward = expert_weights1[0]->dims()[0]; // dim is [dim_feedforward, dim_embed] + int topk = ctx.Attr("topk"); + int mp_size = ctx.Attr("mp_size"); + int mp_rank = ctx.Attr("mp_rank"); + int num_expert = ctx.Attr("num_expert"); + int world_size = ctx.Attr("world_size"); + int moe_ring_id = ctx.Attr("moe_ring_id"); + bool approximate = ctx.Attr("approximate"); + + int tot_expert = world_size * num_expert; + // after slice, bsz_seq should be change + int sliced_bsz_seq = bsz_seq; + int start = 0; + int end = 0; + if (mp_size > 1) { + start = bsz_seq / world_size * mp_rank; + end = std::min(start + bsz_seq / world_size, bsz_seq); + sliced_bsz_seq = end - start; + } + int out_batch_size = sliced_bsz_seq * topk; + // slice + Tensor sliced_inp; + sliced_inp.Resize({{sliced_bsz_seq, dim_embed}}); + dev_ctx.Alloc(&sliced_inp, sliced_inp.numel() * sizeof(T)); + // gate linear + Tensor gate_out; + gate_out.Resize({{sliced_bsz_seq, tot_expert}}); + dev_ctx.Alloc(&gate_out, gate_out.numel() * sizeof(T)); + // topk + Tensor topk_value, topk_idx; + topk_value.Resize({{sliced_bsz_seq, topk}}); + dev_ctx.Alloc(&topk_value, topk_value.numel() * sizeof(T)); + topk_idx.Resize({{sliced_bsz_seq, topk}}); + dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(T)); + // local expert count, global expert count + Tensor local_expert_count, global_expert_count; + local_expert_count.Resize({{tot_expert}}); + global_expert_count.Resize({{tot_expert}}); + dev_ctx.Alloc(&local_expert_count, local_expert_count.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&global_expert_count, global_expert_count.numel() * sizeof(int64_t)); + // fwd_expert_count, fwd_batch_size + Tensor fwd_expert_count, fwd_batch_size; + Tensor fwd_expert_count_cpu, fwd_batch_size_cpu; + fwd_expert_count.Resize({{num_expert}}); + fwd_batch_size.Resize({{1}}); + dev_ctx.Alloc(&fwd_expert_count, fwd_expert_count.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&fwd_batch_size, fwd_batch_size.numel() * sizeof(int64_t)); + // pos, temp pos + Tensor pos, temp_pos; + pos.Resize({{out_batch_size}}); + temp_pos.Resize({{out_batch_size}}); + dev_ctx.Alloc(&pos, pos.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&temp_pos, temp_pos.numel() * sizeof(int64_t)); + // cumsum + Tensor lec_cum; + lec_cum.Resize({{tot_expert}}); + dev_ctx.Alloc(&lec_cum, lec_cum.numel() * sizeof(int64_t)); + // fused moe ffn tmp out + Tensor index_select_out; + index_select_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&index_select_out, index_select_out.numel() * sizeof(T)); + Tensor global_gather_out; + global_gather_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&global_gather_out, global_gather_out.numel() * sizeof(T)); + Tensor moe_gather_out; + moe_gather_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&moe_gather_out, moe_gather_out.numel() * sizeof(T)); + Tensor bmm_out; + bmm_out.Resize({{sliced_bsz_seq, 1, dim_embed}}); + dev_ctx.Alloc(&bmm_out, bmm_out.numel() * sizeof(T)); + Tensor all_gather_out; + all_gather_out.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&all_gather_out, all_gather_out.numel() * sizeof(T)); + // topk tensor + Tensor topk_tensor; + topk_tensor.Resize({{1}}); + dev_ctx.Alloc(&topk_tensor, topk_tensor.numel() * sizeof(int64_t)); + phi::FullKernel(dev_ctx, {1}, topk, pos.dtype(), &topk_tensor); + + // []. init workspace for cublasLt transform + Tensor input_workspace, output_workspace, cublaslt_workspace; + // for input and output transform data is CUBLASLT_ORDER_COL32 format, + int m_max = bsz_seq, k_max = std::max({dim_embed, dim_feedforward}), + n_max = std::max({output_size, dim_embed, dim_feedforward}); + // maybe need to change the size of workspace here + + input_workspace.Resize({{(m_max * k_max + 31) / 32 * 32}}); + dev_ctx.Alloc(&input_workspace, + input_workspace.numel() * sizeof(int8_t)); + + output_workspace.Resize({{(n_max * m_max + 31) / 32 * 32}}); + dev_ctx.Alloc(&output_workspace, + output_workspace.numel() * sizeof(int32_t)); + + cublaslt_workspace.Resize({{3000000}}); + dev_ctx.Alloc(&cublaslt_workspace, + cublaslt_workspace.numel() * sizeof(int8_t)); + + // calc + auto *out = ctx.Output("Out"); + auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); + + Tensor buf0, moe_out; + buf0.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&buf0, buf0.numel() * sizeof(T)); + moe_out.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&moe_out, moe_out.numel() * sizeof(T)); + + const T *x_data; + x_data = input_x->data(); + + int layers = qkv_weights.size(); + + for (int i = 0; i < layers; ++i) { +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step1, pre layernorm"; +#endif + // step1. layer_norm + if (i == 0) { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + // layer norm后,对输出做scale,因此输出是int8,在input_workspace中 + ln_compute.ComputeForward(x_data, + ln_scale_data, + ln_bias_data, + input_workspace.data(), + ln_mean_data, + ln_var_data, + nullptr, + 0, + qkv_in_scale[i], + quant_round_type, + quant_max_bound, + quant_min_bound); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step2, qkv"; +#endif + // step2. qkv + const Tensor *qkv_bias = + qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + // NOTE: in decoder stage, bias is fused in fmha + const Tensor *bias = time_step ? nullptr : qkv_bias; + // 输入是int8,input workspace,输出是T,qkv_out + qkv_compute.ComputeForwardINT8ToT(qkv_weights[i], + qkv_in_scale[i], + &input_workspace, // input + bias, + &qkv_out, // out, T + &output_workspace, // out tmp, int32 + &qkv_out, // bias out, T + qkv_out_scales[i], + &cublaslt_workspace); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3.1 fmha"; +#endif + // step3. fmha + const Tensor *cache_kv = + cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + + if (time_step) { // generation decoder stage + // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; + fmha(dev_ctx, + qkv_out, + *qkv_bias, + *src_mask, + sequence_lengths, + nullptr, + beam_cache_offset, + cache_kv_out, + &fmha_out, + bsz, + beam_size, + max_seq_len, + num_head, + dim_head, + time_step_cpu, + 0, + 1. / sqrt(dim_head)); + } else if (cache_kv_out) { // generation context stage + fmha_compute.ComputeForward(qkv_out, + nullptr, + src_mask, + &transpose_out_2, + nullptr, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out); + // [3, bsz, num_head, seq_len, head_dim] + T *qkv_data = transpose_out_2_data; + int64_t q_size = bsz * seq_len * num_head * dim_head; + int64_t k_size = q_size; + const T *q_ptr = qkv_data; + const T *k_ptr = q_ptr + q_size; + const T *v_ptr = k_ptr + k_size; + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + T *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; + + T *cache_k_ptr = cache_kv_data; + T *cache_v_ptr = cache_kv_data + cache_k_size; + + write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + bsz, + num_head, + seq_len, + max_seq_len, + dim_head); + } else { // not generation + VLOG(0) << "not support!"; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3.2 out linear"; +#endif + // T -> int8 + out_linear_compute.ComputeForwardTToINT8(out_linear_weights[i], + out_linear_in_scale[i], + &fmha_out, + &input_workspace, // input tmp, 先将输入量化 + nullptr, + &output_workspace, // output, int32 + nullptr, + &cublaslt_workspace, + quant_round_type, + quant_max_bound, + quant_min_bound); + // 输出在output_workspace + AllReduce(output_workspace, + ring_id, + bsz * seq_len * num_head * dim_head, + dev_ctx); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step4"; +#endif + + // step5. ln(residual + dropout(input + bias)) + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases[i]->data(); + // input type is int32, src is T, dst is T + DequantSkipLoadAndStoreResidual load(output_workspace.data(), out_linear_bias_data, x_data, + out_linear_out_scales[i]->data(), bias_dropout_residual_out_data, 0.0f, dim_embed); + // 改为输出先不做scale,输出是fp16,输出到buf0 + AffineQuantStore store(buf0.data(), dim_embed, ln_scale_data, ln_bias_data); + DispatchLayerNorm(dev_ctx.stream(), load, store, bsz_seq, dim_embed, epsilon, ln_mean_data, ln_var_data); + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step5"; +#endif + // moe + // step2 resize and slice ln_out + if (mp_size > 1) { + sliced_inp = buf0.Slice(start, end); + } else { + sliced_inp = buf0; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, gate & topk"; +#endif + // step3 gate & topk + // 这里不做量化 + phi::MatMulAndAdd(dev_ctx, + gate_weights[i], + &sliced_inp, + gate_biases[i], + false, + false, + true, // compute bias + &gate_out, + &gate_out); + phi::TopkKernel(dev_ctx, + gate_out, + topk, // scalar + -1, + true, + false, + &topk_value, + &topk_idx); + // step4 prepare forward + // step4.1 number count +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, number count"; +#endif + phi::NumberCountKernel(dev_ctx, topk_idx, tot_expert, &local_expert_count); + // step4.2 all_to_all +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, all_to_all"; +#endif + if (world_size > 1) { + phi::AllToAll(local_expert_count, global_expert_count, moe_ring_id, dev_ctx); + } else { + global_expert_count = local_expert_count; + } + + // global expert count resize + global_expert_count.Resize({{world_size, num_expert}}); + // fwd expert count +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, fwd expert count"; +#endif + phi::SumKernel(dev_ctx, + global_expert_count, + phi::IntArray({0}), + global_expert_count.dtype(), + false, + &fwd_expert_count); + // fwd batch size + phi::SumKernel(dev_ctx, + fwd_expert_count, + phi::IntArray({}), // axis is None + fwd_expert_count.dtype(), + false, + &fwd_batch_size); + // step4.3 cumsum & assign pos +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, cumsum"; +#endif + phi::CumsumKernel(dev_ctx, + local_expert_count, + 0, + false, + false, + false, + &lec_cum); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, assign pos"; +#endif + phi::AssignPosCompute(dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, floor divide"; +#endif + if (topk > 1) { + phi::FloorDivideKernel(dev_ctx, + pos, + topk_tensor, + &temp_pos); + } else { + temp_pos = pos; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, tensor copy"; +#endif + framework::TensorCopySync(fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); + framework::TensorCopySync(fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); + int fwd_bsz = fwd_batch_size_cpu.data()[0]; + + Tensor global_scatter_out; + global_scatter_out.Resize({{fwd_bsz, dim_embed}}); + dev_ctx.Alloc(&global_scatter_out, global_scatter_out.numel() * sizeof(T)); + + Tensor all_expert_out; + all_expert_out.Resize({{fwd_bsz, dim_embed}}); + dev_ctx.Alloc(&all_expert_out, all_expert_out.numel() * sizeof(T)); + + // global_scatter_out.Resize({{fwd_bsz, dim_embed}}); + // all_expert_out.Resize({{fwd_bsz, dim_embed}}); + + // step 5, MOEScatter + // step 5.1, index select + // suppose tmp_pos->shape != [0] +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, index select"; +#endif + phi::IndexSelectKernel(dev_ctx, sliced_inp, temp_pos, 0, &index_select_out); + if (world_size > 1) { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 5.2, global_scatter + if (map->has(moe_ring_id)) { + phi::GlobalScatterProcessGroupFunctor(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } else { + phi::GlobalScatterFunctor(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } + } else { + global_scatter_out = index_select_out; + } + + // step 6, Expert Computation +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, Expert Computation"; +#endif + if (fwd_bsz != 0) { + phi::funcs::ConcatFunctor concat; // fp16 + std::vector tmp_expert_out; + int last_index = 0; + for (int idx = 0; idx < num_expert; idx++) { + int cur_expert_count = fwd_expert_count_cpu.data()[idx]; + if (cur_expert_count <= 0) { + continue; + } + int end = cur_expert_count + last_index; + + Tensor expert_in_tmp; // int8_t + expert_in_tmp.Resize({{(cur_expert_count * dim_feedforward + 31) / 32 * 32 }}); + dev_ctx.Alloc(&expert_in_tmp, expert_in_tmp.numel() * sizeof(int8_t)); + + Tensor expert_out1; // int32_t + expert_out1.Resize({{(cur_expert_count * dim_feedforward + 31) / 32 * 32}}); + dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(int32_t)); + + Tensor expert_out2; // T(fp16) + expert_out2.Resize({{cur_expert_count, dim_embed}}); + dev_ctx.Alloc(&expert_out2, expert_out2.numel() * sizeof(T)); + // act_bias_out.Resize({{cur_expert_count, dim_feedforward}}); maybe int8_t? + // maybe use input_workspace and output workspace? + // dev_ctx.Alloc(&act_bias_out, act_bias_out.numel() * sizeof(T)); + + // input is int32_t, output is int8_t + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, cur_expert_count, dim_feedforward, dropout_param); + + Tensor tmp_inp = global_scatter_out.Slice(last_index, end); // fp16, T + int expert_idx = i * num_expert + idx; + // T to int8_t, matmul, dont compute bias + MatMulTToINT8(dev_ctx, + expert_weights1[expert_idx], + expert_weight1_in_scale[expert_idx], + &tmp_inp, + &expert_in_tmp, + &expert_out1, + cur_expert_count, + dim_feedforward, + dim_embed, + &cublaslt_workspace, // maybe space not enough + quant_round_type, + quant_max_bound, + quant_min_bound); + // act bias, input is int32_t, output is int8_t + fused_act_dropout_helper.DropoutActBias( + dev_ctx, + expert_out1.data(), + expert_biases1[expert_idx]->data(), + "gelu", + expert_in_tmp.data(), + nullptr, + expert_weight1_in_scale[expert_idx], + expert_weight1_out_scales[expert_idx]->data(), + 0, // data offset + expert_weight2_in_scale[expert_idx], + quant_round_type, + quant_max_bound, + quant_min_bound, + approximate); + // linear2, int8_t to T + MatMulINT8ToT(dev_ctx, + expert_weights2[expert_idx], + expert_weight2_in_scale[expert_idx], + &expert_in_tmp, + expert_biases2[expert_idx], + &expert_out2, + &expert_out1, // output_tmp + &expert_out2, + expert_weight2_out_scales[expert_idx], + cur_expert_count, + dim_embed, + dim_feedforward, + true, + &cublaslt_workspace); + tmp_expert_out.emplace_back(expert_out2); + last_index = end; + } + concat(dev_ctx, tmp_expert_out, 0, &all_expert_out); + } else { + all_expert_out = global_scatter_out; + } + + // step7. MOEGather +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, MOEGather"; +#endif + if (world_size > 1) { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 7.1, global_gather + if (map->has(moe_ring_id)) { + phi::GlobalGatherProcessGroupFunctor(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } else { + phi::GlobalGatherFunctor(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } + } else { + global_gather_out = all_expert_out; + } + // step 7.2, local_gather or scatter + // suppose pos->shape != [0] +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, local_gather or scatter"; +#endif + phi::ScatterKernel(dev_ctx, + moe_gather_out, + pos, + global_gather_out, + true, + &moe_gather_out); + // step 8, reshape & bmm + // moe gather out reshape +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, reshape & bmm"; +#endif + moe_gather_out.Resize({{sliced_bsz_seq, topk, dim_embed}}); + topk_value.Resize({{sliced_bsz_seq, 1, topk}}); + phi::BmmKernel(dev_ctx, topk_value, moe_gather_out, &bmm_out); + bmm_out.Resize({{sliced_bsz_seq, dim_embed}}); + // step 9, AllGather +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, AllGather"; +#endif + if (mp_size > 1) { + // all gather + phi::AllGather(bmm_out, all_gather_out, moe_ring_id, dev_ctx); + } else { + all_gather_out = bmm_out; + } + + // step 11, add residual +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, add residual"; +#endif + if (i < layers - 1) { + // add residual & next layer norm & qkv quant + auto *ln_scale_data = ln_scales[i + 1]->data(); + auto *ln_bias_data = ln_biases[i + 1]->data(); + // input type is T, src is T, dst is T + DequantSkipLoadAndStoreResidual load(all_gather_out.data(), nullptr, bias_dropout_residual_out_data, + nullptr, moe_out.data(), 0.0f, dim_embed); + AffineQuantStore store(input_workspace.data(), dim_embed, + ln_scale_data, ln_bias_data, qkv_in_scale[i + 1], quant_round_type, quant_max_bound, quant_min_bound); + DispatchLayerNorm(dev_ctx.stream(), load, store, bsz_seq, dim_embed, epsilon, ln_mean_data, ln_var_data); + } else { + // last layer, only add residual, T + phi::AddKernel(dev_ctx, all_gather_out, bias_dropout_residual_out, &moe_out); + } + + x_data = moe_out.data(); + + } // end for layer loop + moe_out.Resize({{bsz, seq_len, dim_embed}}); + *out = moe_out; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_multi_transformer_moe_int8, + ops::FusedMultiTransformerMoeINT8OpKernel, + ops::FusedMultiTransformerMoeINT8OpKernel); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cc new file mode 100644 index 0000000000000..2132d9774eb02 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cc @@ -0,0 +1,319 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace operators { + +class FusedMultiTransformerMoeOp : public framework::OperatorWithKernel { + private: + static constexpr const char *OpName = "FusedMultiTransformerMoeOp"; + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { +#define CHECK_INPUT(name) \ + OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName) +#define CHECK_INPUTS(name) \ + OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName) +#define CHECK_OUTPUT(name) \ + OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName) +#define CHECK_OUTPUTS(name) \ + OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName) + + CHECK_INPUT(X); + + // attention + CHECK_INPUTS(QKVW); + CHECK_INPUTS(OutLinearW); + + if (ctx->HasInput("TimeStep")) { + CHECK_INPUTS(CacheKV); + } + + if (ctx->HasInputs("CacheKV")) { + CHECK_OUTPUTS(CacheKVOut); + } + + // moe + CHECK_INPUTS(GateWeight); + CHECK_INPUTS(GateBias); + CHECK_INPUTS(ExpertWeight1); + CHECK_INPUTS(ExpertWeight2); + + // out + CHECK_OUTPUT(Out); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputsDim("QKVW")[0]; + bool trans_qkvw = ctx->Attrs().Get("trans_qkvw"); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 3, + platform::errors::InvalidArgument("The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + PADDLE_ENFORCE_EQ(y_dim.size(), + 4, + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ( + x_dim[2], + trans_qkvw ? y_dim[3] : y_dim[0], + platform::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is " + "true) or y_dim[0](trans_qkvw is false)" + "must be equal. But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, + y_dim)); + + if (ctx->HasInputs("CacheKV")) { + // [2, batch_size, num_head, max_seq_len, head_size] + const auto &c_dims = ctx->GetInputsDim("CacheKV"); + const auto &c_dim = c_dims[0]; + + PADDLE_ENFORCE_EQ( + c_dim.size(), + 5, + paddle::platform::errors::InvalidArgument( + "The CacheKV must be 5 dims, but got %d", c_dim.size())); + PADDLE_ENFORCE_EQ(c_dim[0], + 2, + paddle::platform::errors::InvalidArgument( + "The first dim of CacheKV must be 2, but got %d", + c_dim[0])); // 2 + PADDLE_ENFORCE_EQ(c_dim[2], + trans_qkvw ? y_dim[1] : y_dim[2], + paddle::platform::errors::InvalidArgument( + "The third dim of CacheKV must be equal with num " + "head %d, but got %d", + trans_qkvw ? y_dim[1] : y_dim[2], + c_dim[2])); // num_head + PADDLE_ENFORCE_EQ(c_dim[4], + trans_qkvw ? y_dim[2] : y_dim[3], + paddle::platform::errors::InvalidArgument( + "The fifth dim of CacheKV must be equal with head " + "size %d, but got %d", + trans_qkvw ? y_dim[2] : y_dim[3], + c_dim[4])); // head_size + } + + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, + const phi::DenseTensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "TimeStep") { + VLOG(10) << "var_name:" << var_name << " need not to transform"; + return expected_kernel_type; + } + return framework::OpKernelType( + expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + } +}; + +class FusedMultiTransformerMoeOpOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("LnScale", + "Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("LnBias", + "Bias is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("QKVW", "The qkv weight tensor.").AsDuplicable(); + AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable(); + AddInput("CacheKV", "(optional) The cached KV for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("PreCaches", + "(optional) The prefix caches for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("RotaryPosEmb", + "(optional) The RoPE embeddings for generation inference.") + .AsDispensable(); + AddInput("BeamCacheOffset", + "(optional) The offset of CacheKV when using BeamSearch.") + .AsDispensable(); + AddInput("TimeStep", + "(optional, int) The time step for generation inference.") + .AsDispensable(); + AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.") + .AsDispensable(); + AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") + .AsDispensable(); + AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); + AddInput("OutLinearBias", "The out_linear bias tensor.") + .AsDispensable() + .AsDuplicable(); + AddInput("GateWeight", "The gate_weights in moe") + .AsDuplicable(); + AddInput("GateBias", "The gate_biases in moe") + .AsDuplicable(); + AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") + .AsDuplicable(); + AddInput("ExpertWeight1", "The expert_weights1 in moe") + .AsDuplicable(); + AddInput("ExpertBias1", "The expert_biases1 in moe") + .AsDuplicable(); + AddInput("ExpertWeight2", "The expert_weights2 in moe") + .AsDuplicable(); + AddInput("ExpertBias2", "The expert_biases2 in moe") + .AsDuplicable(); + AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") + .AsDispensable() + .AsDuplicable(); + AddOutput("Out", "Result after multi ."); + AddAttr("pre_layer_norm", + "if true, the attention op uses pre_layer_norm architecure, " + "else, uses post_layer_norm architecuture. " + "[default true].") + .SetDefault(true); + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, + true, + platform::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + }); + + AddAttr("dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, + true, + platform::errors::InvalidArgument( + "'dropout_rate' must be between 0.0 and 1.0.")); + }); + + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "The meaning is the same as 'attn_dropout_implementation'.") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", + true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + AddAttr("act_method", "act_method") + .SetDefault("gelu") + .AddCustomChecker([](const std::string &act_type) { + PADDLE_ENFORCE_EQ( + act_type == "gelu" || act_type == "geglu" || act_type == "relu" || act_type == "none", + true, + platform::errors::InvalidArgument( + "Only support `gelu`, `geglu`, `relu`, `none` activation in " + "FusedMultiTransformer. ")); + }); + + AddAttr( + "trans_qkvw", + "Whether the weights of qkv should be transposed. If true," + "the shape eights of qkv should be [3, num_head, dim_head, dim_embed]." + "Otherwise the shape of weights of qkv should be" + "[dim_embed, 3, num_head, dim_head]") + .SetDefault(true); + + AddAttr( + "ring_id", + "ring id for tensor model parallel. distributed training and inference") + .SetDefault(-1); + // for moe layer + AddAttr( + "topk", + "gate's topk im moe") + .SetDefault(2); + AddAttr( + "mp_size", + "mp size") + .SetDefault(1); + AddAttr( + "mp_rank", + "mp rank") + .SetDefault(0); + AddAttr( + "num_expert", + "experts num im moe") + .SetDefault(1); + AddAttr( + "world_size", + "world size") + .SetDefault(1); + AddAttr( + "moe_ring_id", + "experts communicate group's ring id") + .SetDefault(1); + AddAttr( + "approximate", + "approximate in expert compute gelu") + .SetDefault(true); + AddComment(R"DOC(fused multi transformer layers op)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + fused_multi_transformer_moe, + ops::FusedMultiTransformerMoeOp, + ops::FusedMultiTransformerMoeOpOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu new file mode 100644 index 0000000000000..6e6b41dd6ab74 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu @@ -0,0 +1,841 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h" + +namespace paddle { +namespace operators { + +using Tensor = phi::DenseTensor; +// #define _DEBUG_FUSED_MULTI_TRANSFORMER + +template +class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto &dev_ctx = ctx.cuda_device_context(); + + auto *time_step = ctx.Input("TimeStep"); + // 0. input + auto *input_x = ctx.Input("X"); + const auto input_x_dims = input_x->dims(); + int bsz = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int bsz_seq = bsz * seq_len; + if (bsz_seq == 0) { + return; + } + // LOG(INFO) << "intput X: bsz: " << bsz << ", seq_len: " << seq_len << ", dim_embed: " << dim_embed; + const std::string act_method = ctx.Attr("act_method"); + auto *sequence_lengths = ctx.Input("SeqLengths"); // nullptr + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } + // LOG(INFO) << "beam_size: " << beam_size; + + auto *out = ctx.Output("Out"); + dev_ctx.Alloc(out, out->numel() * sizeof(T)); + + // 1. layer norm + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + if (!pre_layer_norm) { + VLOG(0) << "not support post layer norm!"; + return; + } + const float epsilon = ctx.Attr("epsilon"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); + + auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); + Tensor ln_mean, ln_var; + ln_mean.Resize({{bsz_seq}}); + auto *ln_mean_data = + dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_var.Resize({{bsz_seq}}); + auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); + + // 2. qkv + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); + const bool trans_qkvw = ctx.Attr("trans_qkvw"); + const auto qkv_w_dims = qkv_weights[0]->dims(); + int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; + // (transA, transB, compute_bias) = (false, trans_qkvw, false) + auto qkv_compute = AttnMatMul(dev_ctx, + false, + trans_qkvw, + bsz_seq, + output_size, + input_size, + compute_bias); + Tensor qkv_out; + qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); + auto *qkv_out_data = + dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + + // 3. fmha + auto dropout_implementation = ctx.Attr("dropout_implementation"); + AttnDropoutParam attn_param( + true, dropout_implementation, 0.0, true, true, 0, nullptr); + auto fmha_compute = + FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + + int time_step_cpu = 0; + if (time_step) { + time_step_cpu = src_mask->dims()[3] - 1; + } + + auto out_seq_len = seq_len; + if (time_step) { + PADDLE_ENFORCE_GT(time_step_cpu, + 0, + platform::errors::PreconditionNotMet( + "The value of time_step must > 0, but now is %d", + time_step_cpu)); + PADDLE_ENFORCE_EQ( + seq_len, + 1, + platform::errors::PreconditionNotMet( + "In decode stage, the seq_len of input must be 1, but now is %d", + seq_len)); + out_seq_len += time_step_cpu; + } + + Tensor transpose_out_2, qk_out; + transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); + auto *transpose_out_2_data = + dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); + qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + + Tensor softmax_out; + Tensor attn_dropout_mask_out, attn_dropout_out; + Tensor qktv_out, fmha_out; + softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *softmax_out_data = + dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + + qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *qktv_out_data = + dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); + fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); + auto *fmha_out_data = + dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + + // 4. out_linear + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + int ring_id = ctx.Attr("ring_id"); + // (transA, transB, compute_bias) = (false, false, false) + auto out_linear_compute = AttnMatMul( + dev_ctx, false, false, bsz_seq, dim_embed, hidden_size, false); + + // 5. ln(residual + bias), pre layernorm in ffn/moe + DropoutParam dropout_param(false, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + dev_ctx, bsz_seq, dim_embed, dropout_param, epsilon); + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + Tensor bias_dropout_residual_out, dropout_mask_out; + T *bias_dropout_residual_out_data = nullptr; + bias_dropout_residual_out.Resize({{bsz_seq, dim_embed}}); + bias_dropout_residual_out_data = + dev_ctx.Alloc(&bias_dropout_residual_out, + bias_dropout_residual_out.numel() * sizeof(T)); + uint8_t *dropout_mask_out_data = nullptr; + + // 6. moe layer: gate / expert_w & b / some attrs + auto gate_weights = ctx.MultiInput("GateWeight"); + auto gate_biases = ctx.MultiInput("GateBias"); + auto expert_weights1 = ctx.MultiInput("ExpertWeight1"); + auto expert_biases1 = ctx.MultiInput("ExpertBias1"); + auto expert_weights2 = ctx.MultiInput("ExpertWeight2"); + auto expert_biases2 = ctx.MultiInput("ExpertBias2"); + int dim_feedforward = expert_weights1[0]->dims()[1]; + // int dim_feedforward = expert_weights1[0]->dims()[2]; // batched gemm + int topk = ctx.Attr("topk"); + int mp_size = ctx.Attr("mp_size"); + int mp_rank = ctx.Attr("mp_rank"); + int num_expert = ctx.Attr("num_expert"); + int world_size = ctx.Attr("world_size"); + int moe_ring_id = ctx.Attr("moe_ring_id"); + bool approximate = ctx.Attr("approximate"); + + int tot_expert = world_size * num_expert; + // after slice, bsz_seq should be change + int sliced_bsz_seq = bsz_seq; + int start = 0; + int end = 0; + if (mp_size > 1) { + start = bsz_seq / world_size * mp_rank; + end = std::min(start + bsz_seq / world_size, bsz_seq); + sliced_bsz_seq = end - start; + } + int out_batch_size = sliced_bsz_seq * topk; + // slice + Tensor sliced_inp; + sliced_inp.Resize({{sliced_bsz_seq, dim_embed}}); + dev_ctx.Alloc(&sliced_inp, sliced_inp.numel() * sizeof(T)); + // gate linear + Tensor gate_out; + gate_out.Resize({{sliced_bsz_seq, tot_expert}}); + dev_ctx.Alloc(&gate_out, gate_out.numel() * sizeof(T)); + // topk + Tensor topk_value, topk_idx; + topk_value.Resize({{sliced_bsz_seq, topk}}); + dev_ctx.Alloc(&topk_value, topk_value.numel() * sizeof(T)); + topk_idx.Resize({{sliced_bsz_seq, topk}}); + dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(T)); + // local expert count, global expert count + Tensor local_expert_count, global_expert_count; + local_expert_count.Resize({{tot_expert}}); + global_expert_count.Resize({{tot_expert}}); + dev_ctx.Alloc(&local_expert_count, local_expert_count.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&global_expert_count, global_expert_count.numel() * sizeof(int64_t)); + // fwd_expert_count, fwd_batch_size + Tensor fwd_expert_count, fwd_batch_size; + Tensor fwd_expert_count_cpu, fwd_batch_size_cpu; + fwd_expert_count.Resize({{num_expert}}); + fwd_batch_size.Resize({{1}}); + dev_ctx.Alloc(&fwd_expert_count, fwd_expert_count.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&fwd_batch_size, fwd_batch_size.numel() * sizeof(int64_t)); + // pos, temp pos + Tensor pos, temp_pos; + pos.Resize({{out_batch_size}}); + temp_pos.Resize({{out_batch_size}}); + dev_ctx.Alloc(&pos, pos.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&temp_pos, temp_pos.numel() * sizeof(int64_t)); + // cumsum + Tensor lec_cum; + lec_cum.Resize({{tot_expert}}); + dev_ctx.Alloc(&lec_cum, lec_cum.numel() * sizeof(int64_t)); + // fused moe ffn tmp out + Tensor index_select_out; + index_select_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&index_select_out, index_select_out.numel() * sizeof(T)); + Tensor global_gather_out; + global_gather_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&global_gather_out, global_gather_out.numel() * sizeof(T)); + Tensor moe_gather_out; + moe_gather_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&moe_gather_out, moe_gather_out.numel() * sizeof(T)); + Tensor bmm_out; + bmm_out.Resize({{sliced_bsz_seq, 1, dim_embed}}); + dev_ctx.Alloc(&bmm_out, bmm_out.numel() * sizeof(T)); + Tensor all_gather_out; + all_gather_out.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&all_gather_out, all_gather_out.numel() * sizeof(T)); + // topk tensor + Tensor topk_tensor; + topk_tensor.Resize({{1}}); + dev_ctx.Alloc(&topk_tensor, topk_tensor.numel() * sizeof(int64_t)); + phi::FullKernel(dev_ctx, {1}, topk, pos.dtype(), &topk_tensor); + // for nccl comm + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + // // expert out, alloc max size + // Tensor global_scatter_out; + // global_scatter_out.Resize({{2 * bsz_seq, dim_embed}}); + // dev_ctx.Alloc(&global_scatter_out, global_scatter_out.numel() * sizeof(T)); + + // Tensor expert_out1, expert_out2, all_expert_out; + // expert_out1.Resize({{2 * bsz_seq, dim_feedforward}}); + // // act_bias_out.Resize({{bsz_seq, dim_feedforward}}); + // expert_out2.Resize({{2 * bsz_seq, dim_embed}}); + // all_expert_out.Resize({{2 * bsz_seq, dim_embed}}); + // dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(T)); + // // dev_ctx.Alloc(&act_bias_out, act_bias_out.numel() * sizeof(T)); + // dev_ctx.Alloc(&expert_out2, expert_out2.numel() * sizeof(T)); + // dev_ctx.Alloc(&all_expert_out, all_expert_out.numel() * sizeof(T)); + + Tensor buf0, moe_out; + buf0.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&buf0, buf0.numel() * sizeof(T)); + moe_out.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&moe_out, moe_out.numel() * sizeof(T)); + + const T *x_data; + x_data = input_x->data(); + + int layers = qkv_weights.size(); + + for (int i = 0; i < layers; ++i) { +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step1, pre layernorm"; +#endif + // step1. layer_norm, only layer 0 + if (i == 0) { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + // TODO(wangxi): can remove mean var in inference + ln_compute.ComputeForward(x_data, + ln_scale_data, + ln_bias_data, + buf0.data(), + ln_mean_data, + ln_var_data); + } + // auto *ln_scale_data = ln_scales[i]->data(); + // auto *ln_bias_data = ln_biases[i]->data(); + // // TODO(wangxi): can remove mean var in inference + // ln_compute.ComputeForward(x_data, + // ln_scale_data, + // ln_bias_data, + // buf0.data(), + // ln_mean_data, + // ln_var_data); + + // step2. qkv +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step2, qkv"; +#endif + const Tensor *qkv_bias = + qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + // NOTE: in decoder stage, bias is fused in fmha + const Tensor *bias = time_step ? nullptr : qkv_bias; + qkv_compute.ComputeForward( + qkv_weights[i], &buf0, bias, &qkv_out, &qkv_out); + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3.1 fmha"; +#endif + // step3. fmha + const Tensor *cache_kv = + cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + + if (time_step) { // generation decoder stage + // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; + fmha(dev_ctx, + qkv_out, + *qkv_bias, + *src_mask, + sequence_lengths, + nullptr, + beam_cache_offset, + cache_kv_out, + &fmha_out, + bsz, + beam_size, + max_seq_len, + num_head, + dim_head, + time_step_cpu, + 0, + 1. / sqrt(dim_head)); + } else if (cache_kv_out) { // generation encoder stage + fmha_compute.ComputeForward(qkv_out, + nullptr, + src_mask, + &transpose_out_2, + nullptr, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out); + // [3, bsz, num_head, seq_len, head_dim] + T *qkv_data = transpose_out_2_data; + int64_t q_size = bsz * seq_len * num_head * dim_head; + int64_t k_size = q_size; + const T *q_ptr = qkv_data; + const T *k_ptr = q_ptr + q_size; + const T *v_ptr = k_ptr + k_size; + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + T *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; + + T *cache_k_ptr = cache_kv_data; + T *cache_v_ptr = cache_kv_data + cache_k_size; + + write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + bsz, + num_head, + seq_len, + max_seq_len, + dim_head); + } else { // not generation + VLOG(0) << "not support!"; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3.2 out linear"; +#endif + // 输出到buf0 + out_linear_compute.ComputeForward( + out_linear_weights[i], &fmha_out, nullptr, &buf0, nullptr); + AllReduce(buf0, ring_id, buf0.numel(), dev_ctx); + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step4"; +#endif + + // step5. ln(residual + dropout(input + bias)),在MHA里的 + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases[i]->data(); + + // pre layer norm : bias_dropout_residual_out is residual + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + buf0.data(), + x_data, // residual, moe out + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + bias_dropout_residual_out_data, + dropout_mask_out_data, + buf0.data(), // output to buf0 + ln_mean_data, + ln_var_data); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step5"; +#endif + // moe + // step2 resize and slice ln_out + if (mp_size > 1) { + sliced_inp = buf0.Slice(start, end); + } else { + sliced_inp = buf0; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, gate & topk"; +#endif + // step3 gate & topk + phi::MatMulAndAdd(dev_ctx, + gate_weights[i], + &sliced_inp, + gate_biases[i], + false, + false, + true, // compute bias + &gate_out, + &gate_out); + phi::TopkKernel(dev_ctx, + gate_out, + topk, // scalar + -1, + true, + false, + &topk_value, + &topk_idx); + // step4 prepare forward + // step4.1 number count +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, number count"; +#endif + phi::NumberCountKernel(dev_ctx, topk_idx, tot_expert, &local_expert_count); + // step4.2 all_to_all +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, all_to_all"; +#endif + if (world_size > 1) { + phi::AllToAll(local_expert_count, global_expert_count, moe_ring_id, dev_ctx); + } else { + global_expert_count = local_expert_count; + } + + // global expert count resize + global_expert_count.Resize({{world_size, num_expert}}); + // fwd expert count +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, fwd expert count"; +#endif + phi::SumKernel(dev_ctx, + global_expert_count, + phi::IntArray({0}), + global_expert_count.dtype(), + false, + &fwd_expert_count); + // fwd batch size + phi::SumKernel(dev_ctx, + fwd_expert_count, + phi::IntArray({}), // axis is None + fwd_expert_count.dtype(), + false, + &fwd_batch_size); + // step4.3 cumsum & assign pos +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, cumsum"; +#endif + phi::CumsumKernel(dev_ctx, + local_expert_count, + 0, + false, + false, + false, + &lec_cum); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, assign pos"; +#endif + phi::AssignPosCompute(dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, floor divide"; +#endif + if (topk > 1) { + phi::FloorDivideKernel(dev_ctx, + pos, + topk_tensor, + &temp_pos); + } else { + temp_pos = pos; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, tensor copy"; +#endif + framework::TensorCopySync(fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); + framework::TensorCopySync(fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); + int fwd_bsz = fwd_batch_size_cpu.data()[0]; + + Tensor global_scatter_out; + global_scatter_out.Resize({{fwd_bsz, dim_embed}}); + dev_ctx.Alloc(&global_scatter_out, global_scatter_out.numel() * sizeof(T)); + + Tensor all_expert_out; + all_expert_out.Resize({{fwd_bsz, dim_embed}}); + dev_ctx.Alloc(&all_expert_out, all_expert_out.numel() * sizeof(T)); + + // global_scatter_out.Resize({{fwd_bsz, dim_embed}}); + // all_expert_out.Resize({{fwd_bsz, dim_embed}}); + + // step 5, MOEScatter + // step 5.1, index select + // suppose tmp_pos->shape != [0] +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, index select"; +#endif + phi::IndexSelectKernel(dev_ctx, sliced_inp, temp_pos, 0, &index_select_out); + if (world_size > 1) { + // auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 5.2, global_scatter + if (map->has(moe_ring_id)) { + phi::GlobalScatterProcessGroupFunctor(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } else { + phi::GlobalScatterFunctor(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + false, + &global_scatter_out); + } + } else { + global_scatter_out = index_select_out; + } + + // step 6, Expert Computation +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, Expert Computation"; +#endif + if (fwd_bsz != 0) { + phi::funcs::ConcatFunctor concat; + std::vector tmp_expert_out; + // if (time_step) { + // // decoder, use batched gemm + // Tensor expert_out1, expert_out2; + // expert_out1.Resize({{num_expert, fwd_bsz, dim_feedforward}}); + // expert_out2.Resize({{num_expert, fwd_bsz, dim_embed}}); + // dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(T)); + // dev_ctx.Alloc(&expert_out2, expert_out2.numel() * sizeof(T)); + + // BatchedMatMulAndAdd(dev_ctx, + // expert_weights1[i], + // &global_scatter_out, + // expert_biases1[i], // bias + // false, + // false, + // true, // compute bias + // true, // is linear1 + // &expert_out1, + // &expert_out1); // bias out + // phi::GeluKernel(dev_ctx, expert_out1, approximate, &expert_out1); + // BatchedMatMulAndAdd(dev_ctx, + // expert_weights2[i], + // &expert_out1, // input + // expert_biases2[i], + // false, + // false, + // true, // compute bias + // false, // is linear1 + // &expert_out2, + // &expert_out2); + // int last_index = 0; + // for (int idx = 0; idx < num_expert; idx++) { + // int cur_expert_count = fwd_expert_count_cpu.data()[idx]; + // if (cur_expert_count <= 0) { + // continue; + // } + // int end = cur_expert_count + last_index; + // // expert_out2 slice + // Tensor tmp_sliced; + // phi::SliceCompute(dev_ctx, + // expert_out2, + // {0, 1}, + // {idx, last_index}, + // {idx + 1, end}, + // {1, 1}, + // {}, + // &tmp_sliced); + // tmp_sliced.Resize({{cur_expert_count, dim_embed}}); // maybe dont need resize + // tmp_expert_out.emplace_back(tmp_sliced); + // last_index = end; + // } + // } else { + + // encoder, use matmul + int last_index = 0; + // std::vector tmp_expert_out; + for (int idx = 0; idx < num_expert; idx++) { + int cur_expert_count = fwd_expert_count_cpu.data()[idx]; + if (cur_expert_count <= 0) { + continue; + } + int end = cur_expert_count + last_index; + + Tensor expert_out1, expert_out2, act_bias_out; + expert_out1.Resize({{cur_expert_count, dim_feedforward}}); + expert_out2.Resize({{cur_expert_count, dim_embed}}); + act_bias_out.Resize({{cur_expert_count, dim_feedforward}}); + dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(T)); + dev_ctx.Alloc(&expert_out2, expert_out2.numel() * sizeof(T)); + dev_ctx.Alloc(&act_bias_out, act_bias_out.numel() * sizeof(T)); + + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, cur_expert_count, dim_feedforward, dropout_param); + + Tensor tmp_inp = global_scatter_out.Slice(last_index, end); + int expert_idx = i * num_expert + idx; + + // linear1 matmul + // VLOG(0) << "moe, Expert Computation, linear1 mul"; + phi::MatMulAndAdd(dev_ctx, + expert_weights1[expert_idx], + &tmp_inp, + nullptr, + false, + false, + false, // dont compute bias + &expert_out1, + nullptr); + // MatMulAndAdd(dev_ctx, + // expert_weights1[i]->data() + idx * dim_embed * dim_feedforward, + // tmp_inp.data(), + // nullptr, // bias + // cur_expert_count, + // dim_feedforward, + // dim_embed, + // false, + // false, + // false, // dont compute bias + // expert_out1.data(), + // nullptr); + + // bias gelu + // VLOG(0) << "moe, Expert Computation, add bias & gelu"; + // inplace + fused_act_dropout_helper.DropoutActBias(dev_ctx, + expert_out1.data(), + expert_biases1[expert_idx]->data(), + "gelu", + act_bias_out.data(), + nullptr, + 1.0, + nullptr, + 0, + 1.0, + 1, + 127.0, + -127.0, + approximate); + // fused_act_dropout_helper.DropoutActBias(dev_ctx, + // expert_out1.data(), + // expert_biases1[i]->data() + idx * dim_feedforward, + // "gelu", + // act_bias_out.data(), + // nullptr, + // 1.0, + // nullptr, + // 0, + // 1.0, + // 1, + // 127.0, + // -127.0, + // approximate); + + // linear2 matmul & add + // VLOG(0) << "moe, Expert Computation, linear2 matmul & add"; + phi::MatMulAndAdd(dev_ctx, + expert_weights2[expert_idx], + &act_bias_out, + expert_biases2[expert_idx], + false, + false, + true, // compute bias + &expert_out2, + &expert_out2); + // MatMulAndAdd(dev_ctx, + // expert_weights2[i]->data() + idx * dim_embed * dim_feedforward, + // act_bias_out.data(), + // expert_biases2[i]->data() + idx * dim_embed, + // cur_expert_count, + // dim_embed, + // dim_feedforward, + // false, + // false, + // true, // compute bias + // expert_out2.data(), + // expert_out2.data()); + // Addmm(dev_ctx, + // *expert_biases2[expert_idx], + // act_bias_out, + // *expert_weights2[expert_idx], + // 1.0, + // 1.0, + // &expert_out2); + tmp_expert_out.emplace_back(expert_out2); + last_index = end; + // } + } + // at last, concat all expert out + concat(dev_ctx, tmp_expert_out, 0, &all_expert_out); + } else { + all_expert_out = global_scatter_out; + } + + // step7. MOEGather +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, MOEGather"; +#endif + if (world_size > 1) { + // auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 7.1, global_gather + if (map->has(moe_ring_id)) { + phi::GlobalGatherProcessGroupFunctor(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } else { + phi::GlobalGatherFunctor(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + false, + &global_gather_out); + } + } else { + global_gather_out = all_expert_out; + } + // step 7.2, local_gather or scatter + // suppose pos->shape != [0] +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, local_gather or scatter"; +#endif + phi::ScatterKernel(dev_ctx, + moe_gather_out, + pos, + global_gather_out, + true, + &moe_gather_out); + // step 8, reshape & bmm + // moe gather out reshape +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, reshape & bmm"; +#endif + moe_gather_out.Resize({{sliced_bsz_seq, topk, dim_embed}}); + topk_value.Resize({{sliced_bsz_seq, 1, topk}}); + phi::BmmKernel(dev_ctx, topk_value, moe_gather_out, &bmm_out); + bmm_out.Resize({{sliced_bsz_seq, dim_embed}}); + // step 9, AllGather +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, AllGather"; +#endif + if (mp_size > 1) { + // all gather + phi::AllGather(bmm_out, all_gather_out, moe_ring_id, dev_ctx); + } else { + all_gather_out = bmm_out; + } + + // step 11, add residual +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, add residual"; +#endif + if (i < layers - 1) { + // add residual & next layer norm + auto *ln_scale_data = ln_scales[i + 1]->data(); + auto *ln_bias_data = ln_biases[i + 1]->data(); + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + all_gather_out.data(), // src + bias_dropout_residual_out_data, // residual + nullptr, // bias + ln_scale_data, + ln_bias_data, + moe_out.data(), // add out, next layer real input, for residual + dropout_mask_out_data, + buf0.data(), // out, after layernorm + ln_mean_data, + ln_var_data); + } else { + // last layer, only add residual + phi::AddKernel(dev_ctx, all_gather_out, bias_dropout_residual_out, &moe_out); + } + + // phi::AddKernel(dev_ctx, all_gather_out, bias_dropout_residual_out, &moe_out); + x_data = moe_out.data(); + + } // layers loop end + moe_out.Resize({{bsz, seq_len, dim_embed}}); + *out = moe_out; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_multi_transformer_moe, + ops::FusedMultiTransformerMoeOpKernel, + ops::FusedMultiTransformerMoeOpKernel); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h new file mode 100644 index 0000000000000..01a5e344ecc54 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h @@ -0,0 +1,273 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +// This file has been adapted from FasterTransformer file: +// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu +// We add License in the head. + +#pragma once +// #include +#include "paddle/fluid/operators/fused/fused_multi_transformer_op.h" +#include "paddle/phi/kernels/gpu/fused_moe_kernel.cu.h" +#include "paddle/fluid/operators/fused/attn_gemm_int8.h" +// #include "paddle/phi/kernels/funcs/eigen/common.h" +// #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +// #include "paddle/phi/kernels/impl/slice_kernel_impl.h" +// #include "paddle/phi/kernels/gelu_kernel.h" +// #include "paddle/fluid/operators/fused/attn_bias_add.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = Tensor; + +// template +// void BatchedMatMulAndAdd(const phi::GPUContext& dev_ctx, +// const Tensor* weight, +// const Tensor* input, +// const Tensor* bias, +// bool istransA, +// bool istransB, +// bool compute_bias, +// bool is_linear1, +// Tensor* output, +// Tensor* bias_out) { +// // Note: for blas.BatchedGEMM API in Paddle, it treats all inputs as row-major. +// // for input [bsz_seqlen, dim_embed] * expert_weight [expert_num, dim_embed, dim_feedforward] +// CBLAS_TRANSPOSE transA = istransA ? CblasTrans : CblasNoTrans; +// CBLAS_TRANSPOSE transB = istransB ? CblasTrans : CblasNoTrans; +// T alpha = static_cast(1.0); +// T beta = static_cast(0.0); + +// // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) +// auto blas = phi::funcs::GetBlas(dev_ctx); +// const int x_ndim = input->dims().size(); +// auto M = input->dims()[x_ndim - 2]; +// auto N = weight->dims()[2]; +// auto K = input->dims()[x_ndim - 1]; +// auto out_batch_size = weight->dims()[0]; +// int64_t strideA = is_linear1 ? 0 : M * K; +// blas.BatchedGEMM(transA, +// transB, +// M, +// N, +// K, +// alpha, +// input->data(), +// weight->data(), +// beta, +// output->data(), +// out_batch_size, +// strideA, +// K * N); +// if (compute_bias) { +// // bias_out = output + bias +// std::vector ins = {output, bias}; +// std::vector outs = {bias_out}; +// phi::funcs::BroadcastKernel( +// dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); +// } +// } + +// template +// void MatMulAndAdd(const phi::GPUContext& dev_ctx, +// const T* weight, // input & output params is data pointer +// const T* input, +// const T* bias, +// int M, +// int N, +// int K, +// bool istransA, +// bool istransB, +// bool compute_bias, +// T* output, +// T* bias_out) { +// // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. +// // here: (transa, transb): nt, input * weight. +// CBLAS_TRANSPOSE transA = istransA ? CblasTrans : CblasNoTrans; +// CBLAS_TRANSPOSE transB = istransB ? CblasTrans : CblasNoTrans; +// T alpha = static_cast(1.0); +// T beta = static_cast(0.0); +// // input->dims()[0], // M +// // weight->dims()[1], // N +// // input->dims()[1], // K +// // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) +// auto blas = phi::funcs::GetBlas(dev_ctx); +// blas.GEMM(transA, +// transB, +// M, +// N, +// K, +// alpha, +// input, +// weight, +// beta, +// output); +// if (compute_bias) { +// // bias_out = output + bias +// // std::vector ins = {output, bias}; +// // std::vector outs = {bias_out}; +// // phi::funcs::BroadcastKernel( +// // dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); +// LaunchBiasAddFwKernel(dev_ctx, +// M, +// N, +// output, +// bias, +// bias_out); +// } +// } + +// template +// using PhiEigenTensor = phi::EigenTensor; + +// using Array1 = Eigen::DSizes; +// using Array2 = Eigen::DSizes; + +// template +// void Addmm(const phi::GPUContext& dev_ctx, +// const Tensor& input, // bias +// const Tensor& x, // input +// const Tensor& y, // weight +// float alpha, +// float beta, +// Tensor* out) { +// auto input_dims = input.dims(); +// auto x_dims = x.dims(); +// auto y_dims = y.dims(); + +// Tensor input_2d(input); +// if (input.dims().size() == 1) { +// input_dims = {1, input.dims()[0]}; +// input_2d.Resize(input_dims); +// } + +// // dev_ctx.template Alloc(out); +// auto blas = phi::funcs::GetBlas(dev_ctx); + +// // calc broadcast dim +// Array2 bcast_dims; +// bcast_dims[0] = x_dims[0] / input_dims[0]; +// bcast_dims[1] = y_dims[1] / input_dims[1]; +// VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]"; +// // broadcast using eigen +// const Tensor& const_ref_input = input_2d; +// auto eigen_input = PhiEigenTensor::From(const_ref_input); +// auto eigen_out = PhiEigenTensor::From(*out); +// auto& place = *dev_ctx.eigen_device(); +// phi::funcs::EigenBroadcast, T, 2>::Eval( +// place, eigen_out, eigen_input, bcast_dims); + +// T t_alpha = static_cast(alpha); +// T t_beta = static_cast(beta); +// blas.GEMM(false, +// false, +// x_dims[0], +// y_dims[1], +// x_dims[1], +// t_alpha, +// x.data(), +// x_dims[1], +// y.data(), +// y_dims[1], +// t_beta, +// out->data(), +// y_dims[1]); +// } + +using phi::backends::gpu::GpuLaunchConfig; +// This function is used to execute GEMM, with input and output's types are T +// and INT8. +template +void MatMulTToINT8(const phi::GPUContext& dev_ctx, + const Tensor* weight, + const float quant_in_scale, + const Tensor* input, + Tensor* input_tmp, + Tensor* output, + int m, + int n, + int k, + Tensor* workspace = nullptr, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + auto helper = std::make_unique>(m, k, n, lt_handle); + quantize_kernel_launcher(input->data(), + input_tmp->data(), + quant_in_scale, + m, + k, + quant_round_type, + quant_max_bound, + quant_min_bound, + dev_ctx.stream()); + + helper->GEMM(input_tmp->data(), + weight->data(), + output->data(), + dev_ctx.stream(), + (void*)workspace->data(), + workspace->numel()); +} + +template +void MatMulINT8ToT(const phi::GPUContext& dev_ctx, + const Tensor* weight, + const float quant_in_scale, + const Tensor* input, + const Tensor* bias, + Tensor* output, + Tensor* output_tmp, + Tensor* bias_out, + const Tensor* dequant_out_scale, + int m, + int n, + int k, + bool compute_bias, + Tensor* workspace = nullptr) { + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + auto helper = std::make_unique>(m, k, n, lt_handle); + auto gpu_config = std::make_unique( + phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, m * n, DequantKernelVecSize)); + + helper->GEMM(input->data(), + weight->data(), + output_tmp->data(), + dev_ctx.stream(), + (void*)workspace->data(), + workspace->numel()); + + dequantize_kernel_launcher(output_tmp->data(), + output->data(), + m, + n, + dev_ctx.stream(), + gpu_config.get(), + quant_in_scale, + dequant_out_scale->data()); + + if (compute_bias) { + // bias_out = output + bias + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; + phi::funcs::BroadcastKernel( + dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); + } +} + +} // operators +} // paddle \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc index 86de140b9cde8..3edb1a733e29d 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc @@ -21,8 +21,6 @@ limitations under the License. */ namespace paddle { namespace operators { -using Tensor = framework::Tensor; - class FusedMultiTransformerOp : public framework::OperatorWithKernel { private: static constexpr const char *OpName = "FusedMultiTransformerOp"; @@ -93,27 +91,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { x_dim, y_dim)); - if (ctx->Attrs().Get("ring_id") == -1) { - if (trans_qkvw) { - PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], - y_dim[3], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(3, num_head, dim_head, dim_embed)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - - } else { - PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3], - y_dim[0], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(dim_embed, 3, num_head, dim_head)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - } - } - if (ctx->HasInputs("CacheKV")) { // [2, batch_size, num_head, max_seq_len, head_size] const auto &c_dims = ctx->GetInputsDim("CacheKV"); @@ -129,13 +106,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { paddle::platform::errors::InvalidArgument( "The first dim of CacheKV must be 2, but got %d", c_dim[0])); // 2 - PADDLE_ENFORCE_EQ(c_dim[1], - x_dim[0], - paddle::platform::errors::InvalidArgument( - "The second dim of CacheKV must be equal with " - "batch size %d, but got %d", - x_dim[0], - c_dim[1])); // batch_size PADDLE_ENFORCE_EQ(c_dim[2], trans_qkvw ? y_dim[1] : y_dim[2], paddle::platform::errors::InvalidArgument( @@ -143,12 +113,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { "head %d, but got %d", trans_qkvw ? y_dim[1] : y_dim[2], c_dim[2])); // num_head - PADDLE_ENFORCE_GT( - c_dim[3], - 0, - paddle::platform::errors::InvalidArgument( - "The forth dim of CacheKV must be greater than 0, but got %d", - c_dim[3])); // cache_seq_len PADDLE_ENFORCE_EQ(c_dim[4], trans_qkvw ? y_dim[2] : y_dim[3], paddle::platform::errors::InvalidArgument( @@ -170,7 +134,7 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, - const Tensor &tensor, + const phi::DenseTensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { if (var_name == "TimeStep") { VLOG(10) << "var_name:" << var_name << " need not to transform"; @@ -199,16 +163,27 @@ class FusedMultiTransformerOpOpMaker AddInput("CacheKV", "(optional) The cached KV for generation inference.") .AsDispensable() .AsDuplicable(); + AddInput("PreCaches", + "(optional) The prefix caches for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("RotaryPosEmb", + "(optional) The RoPE embeddings for generation inference.") + .AsDispensable(); + AddInput("BeamCacheOffset", + "(optional) The offset of CacheKV when using BeamSearch.") + .AsDispensable(); AddInput("TimeStep", "(optional, int) The time step for generation inference.") .AsDispensable(); + AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.") + .AsDispensable(); AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") .AsDispensable(); AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); AddInput("OutLinearBias", "The out_linear bias tensor.") .AsDispensable() .AsDuplicable(); - AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") .AsDuplicable(); AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") @@ -223,17 +198,39 @@ class FusedMultiTransformerOpOpMaker AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op") .AsDispensable() .AsDuplicable(); - + AddInput("QKVWScale", "QKVWScale") + .AsDispensable() + .AsDuplicable(); + AddInput("OutLinearWScale", "OutLinearWScale") + .AsDispensable() + .AsDuplicable(); + AddInput("FFN1WeightScale", "FFN1WeightScale") + .AsDispensable() + .AsDuplicable(); + AddInput("FFN2WeightScale", "FFN2WeightScale") + .AsDispensable() + .AsDuplicable(); AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") .AsDispensable() .AsDuplicable(); AddOutput("Out", "Result after multi ."); - AddAttr("pre_layer_norm", "if true, the attention op uses pre_layer_norm architecure, " "else, uses post_layer_norm architecuture. " "[default true].") .SetDefault(true); + AddAttr("rotary_emb_dims", + "the Attr(dims) for RotaryPosEmb's Computation [default 0].") + .SetDefault(0) + .AddCustomChecker([](const int &rotary_emb_dims) { + PADDLE_ENFORCE_EQ( + rotary_emb_dims >= 0 && rotary_emb_dims <= 2, + true, + platform::errors::InvalidArgument( + "'rotary_emb_dims' in Op(Rotray) should be between" + "0 and 2, But received [%s].", + rotary_emb_dims)); + }); AddAttr("epsilon", "Constant for numerical stability [default 1e-5].") .SetDefault(1e-5) @@ -272,7 +269,17 @@ class FusedMultiTransformerOpOpMaker "dropout_implementation can only be downgrade_in_infer or " "upscale_in_train")); }); - AddAttr("act_method", "act_method").SetDefault("gelu"); + AddAttr("act_method", "act_method") + .SetDefault("gelu") + .AddCustomChecker([](const std::string &act_type) { + PADDLE_ENFORCE_EQ( + act_type == "gelu" || act_type == "geglu" || act_type == "relu" || act_type == "none", + true, + platform::errors::InvalidArgument( + "Only support `gelu`, `geglu`, `relu`, `none` activation in " + "FusedMultiTransformer. ")); + }); + AddAttr( "trans_qkvw", "Whether the weights of qkv should be transposed. If true," @@ -281,11 +288,14 @@ class FusedMultiTransformerOpOpMaker "[dim_embed, 3, num_head, dim_head]") .SetDefault(true); + AddAttr("quant_weight","Whether do weight quant") + .SetDefault(false); + AddAttr( "ring_id", "ring id for tensor model parallel. distributed training and inference") .SetDefault(-1); - + AddComment(R"DOC(fused multi transformer layers op)DOC"); } }; @@ -309,3 +319,4 @@ REGISTER_OP_VERSION(fused_multi_transformer) "trans_qkvw", "A flag to indicate whether to transpose for weights of qkv.", true)); + diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index 5cf22885aabba..d0f2c7ba08fe9 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -21,34 +21,99 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { using U = LayerNormParamType; auto &dev_ctx = ctx.cuda_device_context(); - auto *time_step = ctx.Input("TimeStep"); + auto *time_step = ctx.Input("TimeStep"); // 0. input - auto *input_x = ctx.Input("X"); + auto *input_x = ctx.Input("X"); const auto input_x_dims = input_x->dims(); int bsz = input_x_dims[0]; int seq_len = input_x_dims[1]; int dim_embed = input_x_dims[2]; int bsz_seq = bsz * seq_len; + // LOG(INFO) << "intput X: bsz: " << bsz << ", seq_len: " << seq_len << ", dim_embed: " << dim_embed; + const std::string act_method = ctx.Attr("act_method"); + bool use_glu = (act_method == "geglu"); + bool remove_padding = false; + auto *sequence_lengths = ctx.Input("SeqLengths"); + if (sequence_lengths) { + remove_padding = true; + } + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } + // LOG(INFO) << "beam_size: " << beam_size; + phi::DenseTensor d_token_tensor; + phi::DenseTensor padding_offset_tensor; + phi::DenseTensor x_remove_padding; + bool encoder_remove_padding = (remove_padding && !time_step); + int token_num = 0; + + auto *out = ctx.Output("Out"); + auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); + + // Init out + if (encoder_remove_padding) { + InitValue(dev_ctx, from_data, out->numel(), static_cast(0.)); + } + + // remove padding in encoder + if (encoder_remove_padding) { + // just for encoder + d_token_tensor.Resize({{1}}); + auto *d_token_num = dev_ctx.Alloc( + &d_token_tensor, d_token_tensor.numel() * sizeof(int)); + // alloc the max size of padding_offset_tensor + padding_offset_tensor.Resize({{bsz_seq}}); + dev_ctx.Alloc(&padding_offset_tensor, + padding_offset_tensor.numel() * sizeof(int)); + InvokeGetPaddingOffset(dev_ctx, + &token_num, + d_token_num, + padding_offset_tensor.data(), + sequence_lengths->data(), + bsz, + seq_len); + padding_offset_tensor.Resize({{token_num}}); + x_remove_padding.Resize({{token_num, dim_embed}}); + dev_ctx.Alloc(&x_remove_padding, x_remove_padding.numel() * sizeof(T)); + InvokeRemovePadding(dev_ctx, + x_remove_padding.data(), + input_x->data(), + padding_offset_tensor.data(), + token_num, + dim_embed); + } else { + token_num = bsz_seq; + } + + if (token_num == 0) { + return; + } + + auto *padding_offset_data = + encoder_remove_padding ? padding_offset_tensor.data() : nullptr; + // whether do weight only quant // 1. layer norm const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); const float epsilon = ctx.Attr("epsilon"); - auto ln_scales = ctx.MultiInput("LnScale"); - auto ln_biases = ctx.MultiInput("LnBias"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); - auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); - Tensor ln_mean, ln_var; - ln_mean.Resize({{bsz_seq}}); + auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + phi::DenseTensor ln_mean, ln_var; + ln_mean.Resize({{token_num}}); auto *ln_mean_data = dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_var.Resize({{bsz_seq}}); + ln_var.Resize({{token_num}}); auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); // 2. qkv // x: qkv's input [batch_size, seq_len, dim_embed] // y: qkv's weight: [3, num_head, dim_head, dim_embed] - auto qkv_weights = ctx.MultiInput("QKVW"); - auto qkv_biases = ctx.MultiInput("QKVBias"); + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); const bool trans_qkvw = ctx.Attr("trans_qkvw"); const auto qkv_w_dims = qkv_weights[0]->dims(); int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; @@ -59,71 +124,95 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; // (transA, transB, compute_bias) = (false, trans_qkvw, false) + // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we + // set compute_bias as false. auto qkv_compute = AttnMatMul(dev_ctx, false, trans_qkvw, - bsz_seq, + token_num, output_size, input_size, - compute_bias); - Tensor qkv_out; - qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); + /*compute_bias=*/false); + phi::DenseTensor qkv_out; + qkv_out.Resize({{token_num, 3, num_head, dim_head}}); auto *qkv_out_data = dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + // 2.1 rotary + auto *rotary_tensor = ctx.Input("RotaryPosEmb"); + const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); + // 3. fmha AttnDropoutParam attn_param( true, "upscale_in_train", 0.0, true, true, 0, nullptr); auto fmha_compute = FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); - auto *src_mask = ctx.Input("SrcMask"); - auto cache_kvs = ctx.MultiInput("CacheKV"); - auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); - // auto *time_step = ctx.Input("TimeStep"); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + + int cache_offset = 0; + + int time_step_cpu = 0; + if (time_step) { + // VLOG(0) << "time_step: " << *time_step; + time_step_cpu = src_mask->dims()[3] - 1; + // VLOG(0) << "time_step_cpu: " << time_step_cpu; + } auto out_seq_len = seq_len; if (time_step) { - PADDLE_ENFORCE_EQ(time_step->place(), - platform::CPUPlace(), - platform::errors::PreconditionNotMet( - "The place of input(TimeStep) must be CPUPlace.")); - // cache_seq_len - int time_step_value = time_step->data()[0]; - PADDLE_ENFORCE_GT(time_step_value, + PADDLE_ENFORCE_GT(time_step_cpu, 0, platform::errors::PreconditionNotMet( "The value of time_step must > 0, but now is %d", - time_step_value)); + time_step_cpu)); PADDLE_ENFORCE_EQ( seq_len, 1, platform::errors::PreconditionNotMet( "In decode stage, the seq_len of input must be 1, but now is %d", seq_len)); - out_seq_len += time_step_value; + out_seq_len += time_step_cpu; + } else { + out_seq_len += cache_offset; + } + + phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; + q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *q_transpose_out_data = + dev_ctx.Alloc(&q_transpose_out, q_transpose_out.numel() * sizeof(T)); + + kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}}); + auto *kv_transpose_out_data = dev_ctx.Alloc( + &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); + + if (encoder_remove_padding) { + InitValue(dev_ctx, + q_transpose_out_data, + q_transpose_out.numel(), + static_cast(0.)); + InitValue(dev_ctx, + kv_transpose_out_data, + kv_transpose_out.numel(), + static_cast(0.)); } - Tensor transpose_out_2, qk_out; - transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); - auto *transpose_out_2_data = - dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); - Tensor softmax_out; - Tensor attn_dropout_mask_out, attn_dropout_out; - Tensor qktv_out, fmha_out; + phi::DenseTensor src_mask_out; + + // [2, bs, num_head, cache_seq_len + seq_len, head_dim] + phi::DenseTensor pre_cache_kv_out; + + phi::DenseTensor softmax_out; + phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; + phi::DenseTensor qktv_out, fmha_out; softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *softmax_out_data = dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); - attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_mask_out_data = dev_ctx.Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_data_data = dev_ctx.Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); - qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); auto *qktv_out_data = dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); @@ -132,97 +221,113 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); // 4. out_linear - auto out_linear_weights = ctx.MultiInput("OutLinearW"); - auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); int ring_id = ctx.Attr("ring_id"); // (transA, transB, compute_bias) = (false, false, false) auto out_linear_compute = AttnMatMul( - dev_ctx, false, false, bsz_seq, dim_embed, hidden_size, false); + dev_ctx, false, false, token_num, dim_embed, hidden_size, false); // 5. ln(residual + bias) DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); - auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); - auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); - Tensor bias_dropout_residual_out, dropout_mask_out; + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; T *bias_dropout_residual_out_data = nullptr; if (pre_layer_norm) { - bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}}); + bias_dropout_residual_out.Resize({{token_num, dim_embed}}); bias_dropout_residual_out_data = dev_ctx.Alloc(&bias_dropout_residual_out, bias_dropout_residual_out.numel() * sizeof(T)); } - dropout_mask_out.Resize({{bsz, seq_len, dim_embed}}); - auto *dropout_mask_out_data = dev_ctx.Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); + uint8_t *dropout_mask_out_data = nullptr; // 6. ffn matmul1 - auto ffn1_weights = ctx.MultiInput("FFN1Weight"); - auto ffn1_biases = ctx.MultiInput("FFN1Bias"); + auto ffn1_weights = ctx.MultiInput("FFN1Weight"); + auto ffn1_weights_scales = + ctx.MultiInput("FFN1WeightScale"); + auto ffn1_biases = ctx.MultiInput("FFN1Bias"); auto ffn1_weight_dim = ffn1_weights[0]->dims(); int dim_ffn = ffn1_weight_dim[1]; + FFNGluHelper ffn1_glu_helper( + dev_ctx, act_method, token_num, dim_ffn / 2, dim_ffn, dim_embed); auto ffn1_linear_compute = AttnMatMul( - dev_ctx, false, false, bsz_seq, dim_ffn, dim_embed, false); - Tensor ffn1_out; - ffn1_out.Resize({{bsz_seq, dim_ffn}}); + dev_ctx, false, false, token_num, dim_ffn, dim_embed, false); + phi::DenseTensor ffn1_out; + ffn1_out.Resize({{token_num, dim_ffn}}); auto *ffn1_out_data = dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); // 7. ffn act + bias DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutHelper fused_act_dropout_helper( - dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); - Tensor ffn1_dropout_out, ffn1_dropout_mask; - ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}}); + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, token_num, dim_ffn, ffn1_dropout_param); + phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; + int tmp_dim_ffn = dim_ffn; + if (use_glu) tmp_dim_ffn /= 2; + int8_t *ffn1_dropout_mask_data = nullptr; + ffn1_dropout_out.Resize({{token_num, tmp_dim_ffn}}); auto *ffn1_dropout_out_data = dev_ctx.Alloc( &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); - ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}}); - auto *ffn1_dropout_mask_data = dev_ctx.Alloc( - &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); // 8. ffn2 matmul - auto ffn2_weights = ctx.MultiInput("FFN2Weight"); - auto ffn2_biases = ctx.MultiInput("FFN2Bias"); + auto ffn2_weights = ctx.MultiInput("FFN2Weight"); + auto ffn2_biases = ctx.MultiInput("FFN2Bias"); auto ffn2_linear_compute = AttnMatMul( - dev_ctx, false, false, bsz_seq, dim_embed, dim_ffn, false); + dev_ctx, false, false, token_num, dim_embed, tmp_dim_ffn, false); // 9. ffn2 residual bias DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); - - // calc - auto *out = ctx.Output("Out"); - auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - Tensor *from_tensor = out; - Tensor tmp_out; - tmp_out.Resize({{bsz, seq_len, dim_embed}}); + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); + + phi::DenseTensor tmp_out, tmp_out_rm_padding; + tmp_out.Resize({{token_num, dim_embed}}); + if (encoder_remove_padding) { + tmp_out_rm_padding.Resize({{token_num, dim_embed}}); + auto *tmp_out_rm_padding_data = dev_ctx.Alloc( + &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); + } auto *tmp_out_data = dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); - auto *x_data = input_x->data(); - Tensor *buf0 = nullptr; - Tensor *buf1 = nullptr; + const T *x_data; + if (encoder_remove_padding) { + x_data = x_remove_padding.data(); + } else { + x_data = input_x->data(); + } + phi::DenseTensor *buf0 = nullptr; + phi::DenseTensor *buf1 = nullptr; // step0: x --> buf1 // step1: buf1 --> buf0 // step2: buf0 --> buf1 int layers = qkv_weights.size(); - if (pre_layer_norm) { - if (layers & 1) { - // odd, set buf1 as out + if (encoder_remove_padding) { + // In the case of variable lengths, the padding needs to be rebuilt + // eventually. So buf0 and buf1 do not need to be changed according to the + // pre_layer_norm and the number of layers. + buf0 = &tmp_out; + buf1 = &tmp_out_rm_padding; + } else { + if (pre_layer_norm) { + if (layers & 1) { + // odd, set buf1 as out + buf0 = &tmp_out; + buf1 = out; + } else { + // even, set buf0 as out + buf0 = out; + buf1 = &tmp_out; + } + } else { buf0 = &tmp_out; buf1 = out; - } else { - // even, set buf0 as out - buf0 = out; - buf1 = &tmp_out; } - } else { - buf0 = &tmp_out; - buf1 = out; } for (int i = 0; i < layers; ++i) { @@ -238,28 +343,26 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { ln_mean_data, ln_var_data); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step1"; -#endif // step2. qkv - const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + const phi::DenseTensor *qkv_bias = + qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; // NOTE: in decoder stage, bias is fused in fmha - const Tensor *bias = time_step ? nullptr : qkv_bias; + const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; if (!pre_layer_norm && i == 0) { + const phi::DenseTensor *tmp_input_x = + (encoder_remove_padding) ? &x_remove_padding : input_x; qkv_compute.ComputeForward( - qkv_weights[i], input_x, bias, &qkv_out, &qkv_out); + qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); } else { qkv_compute.ComputeForward( qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step2"; -#endif // step3. fmha - const Tensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; - Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + const phi::DenseTensor *cache_kv = + cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; if (time_step) { // generation decoder stage // [2, batch_size, num_head, max_seq_len, head_size] @@ -268,35 +371,89 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { qkv_out, *qkv_bias, *src_mask, + sequence_lengths, + rotary_tensor, + beam_cache_offset, cache_kv_out, &fmha_out, bsz, + beam_size, max_seq_len, num_head, dim_head, - time_step->data()[0], + time_step_cpu, + rotary_emb_dims, 1. / sqrt(dim_head)); } else if (cache_kv_out) { // generation context stage - // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - nullptr, - src_mask, - &transpose_out_2, - nullptr, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); - // [3, bsz, num_head, seq_len, head_dim] - T *qkv_data = transpose_out_2_data; - int64_t q_size = bsz * seq_len * num_head * dim_head; - int64_t k_size = q_size; - const T *q_ptr = qkv_data; - const T *k_ptr = q_ptr + q_size; - const T *v_ptr = k_ptr + k_size; + const phi::DenseTensor *pre_cache_kv_tensor = nullptr; + phi::DenseTensor *pre_cache_kv_out_tmp = nullptr; + phi::DenseTensor *src_mask_tmp = nullptr; + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, + src_mask, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + pre_cache_kv_out_tmp, + &qk_out, + src_mask_tmp, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); + + const T *k_ptr = nullptr; + const T *v_ptr = nullptr; + + if (cache_offset > 0) { + // [2, bsz, num_head, cache_offset + seq_len, head_dim] + const T *kv_data = pre_cache_kv_out.data(); + k_ptr = kv_data; + int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; + v_ptr = k_ptr + k_size; + } else { + // [3, bsz, num_head, seq_len, head_dim] + int64_t k_size = bsz * seq_len * num_head * dim_head; + const T *q_ptr = q_transpose_out_data; + k_ptr = kv_transpose_out_data; + v_ptr = k_ptr + k_size; + } // [2, bsz, num_head, max_seq_len, head_dim] int max_seq_len = cache_kv_out->dims()[3]; @@ -306,35 +463,72 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { T *cache_k_ptr = cache_kv_data; T *cache_v_ptr = cache_kv_data + cache_k_size; + const int seq_len_tmp = seq_len + cache_offset; write_cache_kv(dev_ctx, cache_k_ptr, cache_v_ptr, k_ptr, v_ptr, + sequence_lengths_data, bsz, num_head, - seq_len, + seq_len_tmp, max_seq_len, dim_head); } else { // not generation // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - cache_kv, - src_mask, - &transpose_out_2, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(cache_kv, + src_mask, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + cache_kv_out, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step3"; #endif - if (pre_layer_norm) { out_linear_compute.ComputeForward( out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); @@ -390,25 +584,31 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { #endif // step6. ffn matmul1 - ffn1_linear_compute.ComputeForward( - ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr); + if (use_glu) { + ffn1_glu_helper.Compute(buf1, + ffn1_weights[i], + ffn1_biases[i], + &ffn1_out, + &ffn1_dropout_out); + } else { + ffn1_linear_compute.ComputeForward( + ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step6"; #endif // step7. act bias // TODO(wangxi): remove dropout mask in inference - fused_act_dropout_helper.DropoutActBias(dev_ctx, - ffn1_out_data, - ffn1_biases[i]->data(), - "gelu", - ffn1_dropout_out_data, - ffn1_dropout_mask_data); -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7"; -#endif - - // step8. ffn matmul2 + if (!use_glu) { + fused_act_dropout_helper.DropoutActBias(dev_ctx, + ffn1_out_data, + ffn1_biases[i]->data(), + act_method, + ffn1_dropout_out_data, + ffn1_dropout_mask_data); + } + // step8. ffn2 matmul if (pre_layer_norm) { ffn2_linear_compute.ComputeForward( ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr); @@ -480,6 +680,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { std::swap(buf0, buf1); } } + if (encoder_remove_padding) { + if (pre_layer_norm) { + InvokeRebuildPadding(dev_ctx, + from_data, + buf0->data(), + padding_offset_data, + token_num, + dim_embed); + } else { + InvokeRebuildPadding(dev_ctx, + from_data, + buf1->data(), + padding_offset_data, + token_num, + dim_embed); + } + } } }; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.h index 761a31ce094d1..94865f4415413 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.h @@ -1,12 +1,9 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,6 +13,8 @@ limitations under the License. */ // https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu // We add License in the head. +#pragma once + #include #include @@ -32,22 +31,37 @@ limitations under the License. */ #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/kernels/funcs/math_function.h" -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +// #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/ProcessGroup.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#endif +// #endif + +#include +#include + +DECLARE_bool(gemm_use_half_precision_compute_type); namespace paddle { namespace operators { -using Tensor = framework::Tensor; +template +void print_tensor(const T *t, int size, const char *name){ + using namespace std; + ofstream out_txt_file; + out_txt_file.open(name, ios::out | ios::trunc); + out_txt_file << fixed; + for(int i=0; i < size; i++){ + out_txt_file << setprecision(8) << static_cast(t[i]) << endl; + } + out_txt_file.close(); +} // for debug // #define _DEBUG_FUSED_MULTI_TRANSFORMER template -static void AllReduce(framework::Tensor &tensor, // NOLINT +static void AllReduce(phi::DenseTensor &tensor, // NOLINT const int ring_id, const int count, const phi::GPUContext &ctx) { @@ -91,6 +105,9 @@ using float16 = plat::float16; #define MMHA_USE_FP32_ACUM_FOR_LOGITS #define MMHA_USE_FP32_ACUM_FOR_OUT +#define MMHA_USE_FP32_ACUM_FOR_FMA +// #define MMHA_USE_HMMA_FOR_REDUCTION + template struct Masked_multihead_attention_params { @@ -108,8 +125,18 @@ struct Masked_multihead_attention_params { // k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first // v [B, num_head, max_seq_len, dim_head] T *cache_kv; + // [B, max_seq_len] + const int* beam_cache_offset = nullptr; + + const int *sequence_lengths{nullptr}; + + // The RoPE embedding, [B, 1, 1, dim_head] + // rotary_emb_dims = 1 if pos_ids_extra is null else 2 + const T *rotary_emb; + int rotary_emb_dims; - int batch_size; + int batch_size; // batch * beam + int beam_width; int num_head; int timestep; // cache_seq_length int max_seq_length; @@ -153,6 +180,17 @@ template <> struct V_vec_ { using Type = uint32_t; }; template <> struct V_vec_ { using Type = uint2; }; template <> struct V_vec_ { using Type = uint4; }; +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct K_vec_acum_fp32_ { +}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +#endif + #ifdef MMHA_USE_FP32_ACUM_FOR_OUT template struct V_vec_acum_fp32_ {}; // template <> struct V_vec_acum_fp32_ { using Type = float; }; @@ -321,6 +359,15 @@ inline __device__ uint32_t mul(uint32_t a, float b) { return res; } +template <> +inline __device__ float2 mul(uint32_t a, float b) { + float2 tmp = half2_to_float2(a); + float2 res; + res.x = tmp.x * b; + res.y = tmp.y * b; + return res; +} + template <> inline __device__ uint2 mul(uint2 a, float b) { uint2 res; @@ -347,6 +394,15 @@ inline __device__ float2 mul(float2 a, float b) { return res; } +template <> +inline __device__ float2 mul(float2 a, uint32_t b) { + float2 tmp_b = half2_to_float2(b); + float2 res; + res.x = a.x * tmp_b.x; + res.y = a.y * tmp_b.y; + return res; +} + template <> inline __device__ float4 mul(float4 a, float b) { float4 res; @@ -357,6 +413,18 @@ inline __device__ float4 mul(float4 a, float b) { return res; } +template +inline __device__ Qk_vec apply_rotary_emb(Qk_vec input_left, + Qk_vec input_right, + Qk_vec cos_emb, + Qk_vec sin_emb, + float alpha) { + Qk_vec res1 = mul(input_left, cos_emb); + Qk_vec res2 = mul(input_right, sin_emb); + res2 = mul(res2, alpha); + return add(res1, res2); +} + inline __device__ float sum(float v) { return v; } inline __device__ float sum(float2 v) { return v.x + v.y; } inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } @@ -406,6 +474,12 @@ inline __device__ float2 fma(float2 a, float2 b, float2 c) { return d; } +inline __device__ float2 fma(float2 a, uint32_t b, float2 c) { + float2 tmp_b = half2_to_float2(b); + float2 d = fma(a, tmp_b, c); + return d; +} + inline __device__ float4 fma(float4 a, float4 b, float4 c) { float4 d; d.x = fma(a.x, b.x, c.x); @@ -527,6 +601,50 @@ inline __device__ float qk_dot_(const K_vec (&q)[N], return qk; } +inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) { + float4 c; + float zero = 0.f; + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], + const uint32_t (&k)[N], + float inv_sqrt_dh) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum inv_q = mul(q[0], inv_sqrt_dh); + K_vec_acum qk_vec = mul(inv_q, k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + inv_q = mul(q[ii], inv_sqrt_dh); + qk_vec = fma(inv_q, k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + template struct Qk_dot { template @@ -537,6 +655,21 @@ struct Qk_dot { } }; +template <> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], + const uint32_t (&k)[N], + float inv_sqrt_dh) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 + return qk_hmma_dot_(q, k, inv_sqrt_dh); +#else + return qk_dot_<4>(q, k, inv_sqrt_dh); +#endif + } +}; + template inline __device__ float block_sum(float *red_smem, float sum) { int warp = threadIdx.x / WARP_SIZE; @@ -630,14 +763,24 @@ __global__ void masked_multihead_attention_kernel( using Qk_vec = typename Qk_vec_::Type; __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; + // batch * beam idx const int bi = blockIdx.y; + // beam id + const int beami = bi % params.beam_width; + // real batch id + const int bbi = bi / params.beam_width; const int hi = blockIdx.x; const int bhi = bi * params.num_head + hi; + const int bbhi = bbi * params.beam_width * params.num_head + hi; const int tid = threadIdx.x; - + const int bi_seq_len_offset = bi * params.max_seq_length; float qk_max = -FLT_MAX; float qk = 0; + int act_time_step = params.sequence_lengths == nullptr + ? params.timestep + : params.sequence_lengths[bi]; + // qkv [B, S=1, 3, num_head, head_dim] int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; @@ -690,13 +833,74 @@ __global__ void masked_multihead_attention_kernel( // we may not require k_bias. k = add(k, k_bias); + // rotary pos emb + if (params.rotary_emb_dims != 0) { + int last_dim = Dh / params.rotary_emb_dims; + int half_lastdim = last_dim / 2; + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const T *cos_base = params.rotary_emb; + const T *sin_base = params.rotary_emb + params.batch_size * Dh; + int stride = half_lastdim / QK_VEC_SIZE; + int stride_all_lastdim = 2 * stride; + int right_id = tid / stride_all_lastdim * stride_all_lastdim + + (tid + stride) % (stride_all_lastdim); + int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE; + int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE; + Qk_vec q_right; + zero(q_right); + q_right = + (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&q_base[qk_right_offset]) + : q_right; + Qk_vec k_right; + zero(k_right); + k_right = + (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&k_base[qk_right_offset]) + : k_right; + + Qk_vec q_right_bias; + zero(q_right_bias); + q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &q_bias_base[qk_right_bias_offset]) + : q_right_bias; + Qk_vec k_right_bias; + zero(k_right_bias); + k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &k_bias_base[qk_right_bias_offset]) + : k_right_bias; + + q_right = add(q_right, q_right_bias); + k_right = add(k_right, k_right_bias); + + Qk_vec cos_emb; + zero(cos_emb); + cos_emb = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&cos_base[rotary_offset]) + : cos_emb; + + Qk_vec sin_emb; + zero(sin_emb); + sin_emb = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&sin_base[rotary_offset]) + : sin_emb; + float alpha = (tid % stride_all_lastdim) < stride ? static_cast(-1) + : static_cast(1); + q = apply_rotary_emb(q, q_right, cos_emb, sin_emb, alpha); + k = apply_rotary_emb(k, k_right, cos_emb, sin_emb, alpha); + } + *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; int co = tid / QK_VECS_IN_16B; int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE; int offset = bhi * params.max_seq_length * Dh + co * params.max_seq_length * QK_ELTS_IN_16B + - params.timestep * QK_ELTS_IN_16B + ci; + act_time_step * QK_ELTS_IN_16B + ci; if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { *reinterpret_cast(¶ms.cache_kv[offset]) = k; } @@ -710,6 +914,7 @@ __global__ void masked_multihead_attention_kernel( } } } + if (QK_VECS_PER_WARP > WARP_SIZE) { constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; @@ -722,7 +927,7 @@ __global__ void masked_multihead_attention_kernel( // qk += static_cast(mask); qk *= params.inv_sqrt_dh; qk_max = qk; - qk_smem[params.timestep] = qk; + qk_smem[act_time_step] = qk; } __syncthreads(); @@ -735,14 +940,14 @@ __global__ void masked_multihead_attention_kernel( __syncthreads(); #endif - using K_vec = typename K_vec_::Type; - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + using K_vec = typename K_vec_::Type; // uint2 + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); // 2 static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; // 32 + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; // 16 - int ko = tid / THREADS_PER_KEY; - int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; + int ko = tid / THREADS_PER_KEY; // 0 ~ 63 + int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; // 0 or 2 static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, ""); @@ -753,11 +958,13 @@ __global__ void masked_multihead_attention_kernel( &q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]); } - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; //128/2 = 64 + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; // 32/2 = 16 T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; - int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + T *k_cache_batch = ¶ms.cache_kv[bbhi * params.max_seq_length * Dh + ki]; + int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; // 160 + const int *beam_offsets = params.beam_cache_offset ? ¶ms.beam_cache_offset[bi_seq_len_offset] : nullptr; for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { K_vec k[K_VECS_PER_THREAD]; @@ -765,12 +972,19 @@ __global__ void masked_multihead_attention_kernel( zero(k_vec_zero); #pragma unroll for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + // get beam_offset of this location + const int beam_offset = beam_offsets ? beam_offsets[ti] * params.num_head * params.max_seq_length * Dh : 0; int jj = ii * params.max_seq_length + ti; - if (ti < params.timestep) { + if (ti < act_time_step) { + // k[ii] = + // (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) + // ? *reinterpret_cast( + // &k_cache[jj * QK_ELTS_IN_16B]) + // : k_vec_zero; k[ii] = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) ? *reinterpret_cast( - &k_cache[jj * QK_ELTS_IN_16B]) + &k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]) : k_vec_zero; } } @@ -780,7 +994,7 @@ __global__ void masked_multihead_attention_kernel( float qk = Qk_dot::dot(q, k, params.inv_sqrt_dh); // bool is_mask = false; - if (ti < params.timestep && tid % THREADS_PER_KEY == 0) { + if (ti < act_time_step && tid % THREADS_PER_KEY == 0) { // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); T mask = params.attn_mask[bi * (params.timestep + 1) + ti]; qk += static_cast(mask); @@ -822,7 +1036,7 @@ __global__ void masked_multihead_attention_kernel( #endif float sum = 0.f; - for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) { + for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { // bool is_mask = false; // float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max); float logit = __expf(qk_smem[ti] - qk_max); @@ -834,7 +1048,7 @@ __global__ void masked_multihead_attention_kernel( // FIXME(wangxi): need add 1.e-6f? float inv_sum = __fdividef(1.f, sum + 1.e-6f); - for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) { + for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); } __syncthreads(); @@ -848,6 +1062,9 @@ __global__ void masked_multihead_attention_kernel( T *v_cache = ¶ms.cache_kv[params.batch_size * params.num_head * params.max_seq_length * Dh + bhi * params.max_seq_length * Dh + vi]; + T *v_cache_batch = ¶ms.cache_kv[params.batch_size * params.num_head * + params.max_seq_length * Dh + + bbhi * params.max_seq_length * Dh + vi]; #ifdef MMHA_USE_FP32_ACUM_FOR_OUT using V_vec_acum = typename V_vec_acum_fp32_::Type; @@ -860,8 +1077,9 @@ __global__ void masked_multihead_attention_kernel( constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; if (Dh == Dh_MAX || vi < Dh) { - for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) { - V_vec v = *reinterpret_cast(&v_cache[ti * Dh]); + for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) { + const int beam_offset = beam_offsets ? beam_offsets[ti] * params.num_head * params.max_seq_length * Dh : 0; + V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh]); #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti]; out = fma(logit, cast_to_float(v), out); @@ -884,18 +1102,18 @@ __global__ void masked_multihead_attention_kernel( V_vec v_bias; zero(v_bias); - if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { + if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { V_vec v = *reinterpret_cast( ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); v_bias = *reinterpret_cast( ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); v = add(v, v_bias); - *reinterpret_cast(&v_cache[params.timestep * Dh]) = v; + *reinterpret_cast(&v_cache[act_time_step * Dh]) = v; #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - out = fma(logits_smem[params.timestep], cast_to_float(v), out); + out = fma(logits_smem[act_time_step], cast_to_float(v), out); #else - out = fma(logits_smem[params.timestep], v, out); + out = fma(logits_smem[act_time_step], v, out); #endif } @@ -970,18 +1188,17 @@ inline size_t smem_size_in_bytes( return max(softmax_sz, red_sz); } -#define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ - size_t smem_sz = \ - smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_head, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ + size_t smem_sz = \ + smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ + constexpr auto kernel_fn = masked_multihead_attention_kernel; \ + if (smem_sz > 0xc000) { \ + cudaFuncSetAttribute( \ + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + } \ + dim3 grid(params.num_head, params.batch_size); \ + kernel_fn<<>>(params) template void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, @@ -990,7 +1207,12 @@ void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, if (params.timestep < 32) { MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); } else if (params.timestep < 2048) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 256, stream); +#else MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream); +#endif } else { MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream); } @@ -998,16 +1220,21 @@ void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, template void fmha(const phi::GPUContext &dev_ctx, - const Tensor &qkv_tensor, - const Tensor &qkv_bias_tensor, - const Tensor &src_mask_tensor, - Tensor *cache_kv_tensor, - Tensor *out_tensor, + const phi::DenseTensor &qkv_tensor, + const phi::DenseTensor &qkv_bias_tensor, + const phi::DenseTensor &src_mask_tensor, + const phi::DenseTensor *sequence_lengths_tensor, + const phi::DenseTensor *rotary_tensor, + const phi::DenseTensor *beam_cache_offset_tensor, + phi::DenseTensor *cache_kv_tensor, + phi::DenseTensor *out_tensor, int batch_size, + int beam_width, int max_seq_length, int num_head, int dim_head, int timestep, + int rotary_emb_dims, float inv_sqrt_dh) { Masked_multihead_attention_params params; params.out = out_tensor->data(); @@ -1016,11 +1243,28 @@ void fmha(const phi::GPUContext &dev_ctx, params.attn_mask = src_mask_tensor.data(); params.cache_kv = cache_kv_tensor->data(); + if (sequence_lengths_tensor) { + params.sequence_lengths = sequence_lengths_tensor->data(); + } + + if (rotary_emb_dims > 0) { + params.rotary_emb = rotary_tensor->data(); + } else { + params.rotary_emb = nullptr; + } + + if (beam_cache_offset_tensor) { + // LOG(INFO) << "beam_cache_offset_tensor.dims: " << beam_cache_offset_tensor->dims().to_str(); + params.beam_cache_offset = beam_cache_offset_tensor->data(); + } + params.batch_size = batch_size; + params.beam_width = beam_width; params.num_head = num_head; params.timestep = timestep; params.max_seq_length = max_seq_length; params.inv_sqrt_dh = inv_sqrt_dh; + params.rotary_emb_dims = rotary_emb_dims; switch (dim_head) { case 10: @@ -1050,17 +1294,54 @@ void fmha(const phi::GPUContext &dev_ctx, } } +template +void fmha(const phi::GPUContext &dev_ctx, + const phi::DenseTensor &qkv_tensor, + const phi::DenseTensor &qkv_bias_tensor, + const phi::DenseTensor &src_mask_tensor, + phi::DenseTensor *cache_kv_tensor, + phi::DenseTensor *out_tensor, + int batch_size, + int max_seq_length, + int num_head, + int dim_head, + int timestep, + float inv_sqrt_dh) { + fmha(dev_ctx, + qkv_tensor, + qkv_bias_tensor, + src_mask_tensor, + nullptr, + nullptr, + nullptr, + cache_kv_tensor, + out_tensor, + batch_size, + 1, + max_seq_length, + num_head, + dim_head, + timestep, + 0, + inv_sqrt_dh); +} + // NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8 constexpr int VEC_16B = 16; template __global__ void write_cache_k_kernel(T *cache_k, const T *k, + const int *seq_lens, const int num_head, const int dim_head, const int seq_len, const int max_seq_len) { const int bi = blockIdx.y; + if (seq_lens && seq_lens[bi] == 0) { + return; + } + const int hi = blockIdx.z; constexpr int X_ELEMS = VEC_16B / sizeof(T); @@ -1094,11 +1375,16 @@ __global__ void write_cache_k_kernel(T *cache_k, template __global__ void write_cache_v_kernel(T *cache_v, const T *v, + const int *seq_lens, const int num_head, const int dim_head, const int seq_len, const int max_seq_len) { const int bi = blockIdx.y; + if (seq_lens && seq_lens[bi] == 0) { + return; + } + const int hi = blockIdx.z; // [bsz, num_head, seq_len, dim_head/x, x] @@ -1124,6 +1410,7 @@ void write_cache_kv(const phi::GPUContext &dev_ctx, T *cache_v, const T *k, const T *v, + const int *seq_lens, const int bsz, const int num_head, const int seq_len, @@ -1147,14 +1434,496 @@ void write_cache_kv(const phi::GPUContext &dev_ctx, // transpose [bsz, num_head, seq_len, dim_head/x, x]-> // [bsz, num_head, dim_head/x, max_seq_len, x] write_cache_k_kernel<<>>( - cache_k, k, num_head, dim_head, seq_len, max_seq_len); + cache_k, k, seq_lens, num_head, dim_head, seq_len, max_seq_len); // copy [bsz, num_head, seq_len, dim_head/x, x]-> // [bsz, num_head, max_seq_len, dim_head/x, x] write_cache_v_kernel<<>>( - cache_v, v, num_head, dim_head, seq_len, max_seq_len); + cache_v, v, seq_lens, num_head, dim_head, seq_len, max_seq_len); +} + +template +void write_cache_kv(const phi::GPUContext &dev_ctx, + T *cache_k, + T *cache_v, + const T *k, + const T *v, + const int bsz, + const int num_head, + const int seq_len, + const int max_seq_len, + const int dim_head) { + write_cache_kv(dev_ctx, + cache_k, + cache_v, + k, v, nullptr, + bsz, num_head, seq_len, + max_seq_len, dim_head); +} + +template +__global__ void add_fusedQKV_bias_transpose_split_kernel( + T *q_buf, + T *kv_buf, + const T *qkv, + const T *qkv_bias, + const int *padding_offset, + const int32_t elem_cnt, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head) { + const int32_t offset = batch_size * seq_len * head_num * size_per_head; + const int32_t hidden_size = head_num * size_per_head; + const int32_t fused_hidden_size = 3 * hidden_size; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = phi::AlignedVector; + LoadT src_vec; + LoadT bias_vec; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + phi::Load(&qkv[linear_index], &src_vec); + int32_t bias_idx = linear_index % fused_hidden_size; + if (ComputeBias) { + phi::Load(&qkv_bias[bias_idx], &bias_vec); +#pragma unroll + for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) { + src_vec[unroll_idx] += bias_vec[unroll_idx]; + } + } + const int32_t token_idx = linear_index / fused_hidden_size; + const int32_t ori_token_idx = + token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]); + const int32_t target_batch_id = ori_token_idx / seq_len; + const int32_t seq_id = ori_token_idx % seq_len; + + // equal to: + // const int qkv_id = (linear_index % fused_hidden_size) / hidden_size; + const int32_t qkv_id = bias_idx / hidden_size; + const int32_t head_id = (linear_index % hidden_size) / size_per_head; + const int32_t size_id = linear_index % size_per_head; + + if (qkv_id == 0) { + phi::Store( + src_vec, + &q_buf[target_batch_id * head_num * seq_len * size_per_head + + head_id * seq_len * size_per_head + seq_id * size_per_head + + size_id]); + } else { + const int32_t kv_store_offset = (qkv_id - 1) * offset; + phi::Store( + src_vec, + &kv_buf[kv_store_offset + + target_batch_id * head_num * seq_len * size_per_head + + head_id * seq_len * size_per_head + seq_id * size_per_head + + size_id]); + } + } +} + +inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { + constexpr int kBlockSize = 128; + constexpr int kNumWaves = 16; + + const int device_id = phi::backends::gpu::GetCurrentDeviceId(); + const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id); + const int max_thread_per_multiprocessor = + phi::backends::gpu::GetGPUMultiProcessors(device_id); + + *num_blocks = + std::max(1, + std::min((n + kBlockSize - 1) / kBlockSize, + sm_count * max_thread_per_multiprocessor / + kBlockSize * kNumWaves)); + return cudaSuccess; +} + +template +void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx, + T *q_buf, + T *kv_buf, + const T *qkv, + const T *qkv_bias, + const int *padding_offset, + const int token_num, + const int batch_size, + const int head_num, + const int seq_len, + const int size_per_head, + bool compute_bias) { + const int32_t elem_cnt = token_num * head_num * size_per_head * 3; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(size_per_head % PackSize, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", + size_per_head, + PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t blocksize = 128; + int32_t grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + if (compute_bias) { + add_fusedQKV_bias_transpose_split_kernel + <<>>(q_buf, + kv_buf, + qkv, + qkv_bias, + padding_offset, + elem_cnt, + batch_size, + seq_len, + token_num, + head_num, + size_per_head); + } else { + add_fusedQKV_bias_transpose_split_kernel + <<>>(q_buf, + kv_buf, + qkv, + qkv_bias, + padding_offset, + elem_cnt, + batch_size, + seq_len, + token_num, + head_num, + size_per_head); + } +} + +template +__global__ void RotrayKernel(const T *input, + const T *cos_emb, + const T *sin_emb, + const int *sequence_lengths, + T *output, + const int rotary_emb_dims, + const int batch_size, + const int head_num, + const int seq_len, + const int last_dim) { + int bi = blockIdx.x; + int hi = blockIdx.y; + int si = blockIdx.z; + if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return; + int half_lastdim = last_dim / 2; + // Note(ZhenyuLi): Calculate the relevant data at one time, so that no + // additional space is required. + for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) { + int base_idx = bi * head_num * seq_len * last_dim + + hi * seq_len * last_dim + si * last_dim; + int left_idx = base_idx + ti; + const int right_idx = base_idx + ti + half_lastdim; + int emb_idx = bi * seq_len * last_dim + si * last_dim + ti; + T input_left = input[left_idx]; + T input_right = input[right_idx]; + T cos_tmp = cos_emb[emb_idx]; + T sin_tmp = sin_emb[emb_idx]; + T res1 = input_left * cos_tmp - input_right * sin_tmp; + T res2 = input_right * cos_tmp + input_left * sin_tmp; + output[left_idx] = res1; + output[right_idx] = res2; + } +} + +template +void rotary_qk(const phi::GPUContext &dev_ctx, + T *q, + T *k, // kv + const T *q_input, // q + const T *k_input, // kv + const T *rotary_emb, + const int *sequence_lengths, + const int rotary_emb_dims, + const int batch_size, + const int head_num, + const int seq_len, + const int dim_head) { + // q_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, + // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] + // kv_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, + // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] rotary_emb [2, bs, + // 1, seq_len, dim_head] -> [2, bs, 1, seq_len * rotary_emb_dims, dim_head / + // rotary_emb_dims] + dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims); + const int last_dim = dim_head / rotary_emb_dims; + auto getBlockSize = [](int dim) { + if (dim > 256) { + return 512; + } else if (dim > 128) { + return 256; + } else if (dim > 64) { + return 128; + } else if (dim > 32) { + return 64; + } else { + return 32; + } + }; + int BlockSize = getBlockSize(last_dim / 2); + const T *cos_emb = rotary_emb; + const T *sin_emb = rotary_emb + batch_size * seq_len * dim_head; + RotrayKernel<<>>( + q_input, + cos_emb, + sin_emb, + sequence_lengths, + q, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + RotrayKernel<<>>( + k_input, + cos_emb, + sin_emb, + sequence_lengths, + k, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); +} + +__global__ void GetPaddingOffset(int *d_token_num, + int *padding_offset, + const int *sequence_lengths, + const int batch_size, + const int max_seq_len) { + // get padding offset of each batch + int total_seq_len = 0; + int cum_offset = 0; + int index = 0; + for (int i = 0; i < batch_size; i++) { + const int seq_len = sequence_lengths[i]; + for (int j = 0; j < seq_len; j++) { + padding_offset[index] = cum_offset; + index++; + } + cum_offset += max_seq_len - seq_len; + total_seq_len += seq_len; + } + d_token_num[0] = total_seq_len; +} + +void InvokeGetPaddingOffset(const phi::GPUContext &dev_ctx, + int *h_token_num, + int *d_token_num, + int *padding_offset, + const int *sequence_lengths, + const int batch_size, + const int max_seq_len) { + GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>( + d_token_num, padding_offset, sequence_lengths, batch_size, max_seq_len); + memory::Copy(platform::CPUPlace(), + h_token_num, + dev_ctx.GetPlace(), + d_token_num, + sizeof(int), + dev_ctx.stream()); } +template +__global__ void RemovePadding(T *output_data, + const T *input_data, + const int *padding_offset, + const int dim_embed) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int src_seq_id = bid + padding_offset[bid]; + const int tgt_seq_id = bid; + + for (int i = tid; i < dim_embed; i += blockDim.x) { + output_data[tgt_seq_id * dim_embed + i] = + input_data[src_seq_id * dim_embed + i]; + } +} + +template +void InvokeRemovePadding(const phi::GPUContext &dev_ctx, + T *output_data, + const T *input_data, + const int *padding_offset, + const int token_num, + const int dim_embed) { + RemovePadding<<>>( + output_data, input_data, padding_offset, dim_embed); +} + +template +__global__ void RebuildPadding(T *output_data, + const T *input_data, + const int *padding_offset, + const int dim_embed) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int dst_seq_id = bid + padding_offset[bid]; + const int src_seq_id = bid; + + for (int i = tid; i < dim_embed; i += blockDim.x) { + output_data[dst_seq_id * dim_embed + i] = + input_data[src_seq_id * dim_embed + i]; + } +} + +template +void InvokeRebuildPadding(const phi::GPUContext &dev_ctx, + T *output_data, + const T *input_data, + const int *padding_offset, + const int token_num, + const int dim_embed) { + // src: [token_num, dim_embed] + // dst: [batch_size * max_seq_len, dim_embed] + RebuildPadding<<>>( + output_data, input_data, padding_offset, dim_embed); +} + +template +__global__ void InitOutValueKernel(T *output_data, + const int64_t numel, + const T init_value) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + int64_t global_thread_idx = bid * blockDim.x + tid; + + for (int linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < numel; + linear_index += step) { + for (int i = 0; i < VecSize; i ++) { + output_data[linear_index + i] = init_value; + } + } +} + +template +void InitValue(const phi::GPUContext &dev_ctx, + T *output_data, + const int64_t numel, + const T init_value) { + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(numel % PackSize, + 0, + platform::errors::PreconditionNotMet( + "numel=%d must be divisible by vec_size=%d", + numel, + PackSize)); + const int pack_num = numel / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + InitOutValueKernel<<>>( + output_data, numel, init_value); +} + +template +__global__ void ActFFNGlu(const T *input, + T *output, + Functor act_functor, + const int token_num, + const int hid_dim, + const int elem_num) { + using LoadT = phi::AlignedVector; + LoadT src_vec1; + LoadT src_vec2; + const int global_tid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = global_tid * VecSize; i < elem_num; + i += gridDim.x * blockDim.x * VecSize) { + int bi = i / hid_dim; + int idx = i % hid_dim; + const T *input_this_thread = input + bi * hid_dim * 2; + T *output_this_thread = output + bi * hid_dim; + phi::Load(&input_this_thread[idx], &src_vec1); + phi::Load(&input_this_thread[idx + hid_dim], &src_vec2); +#pragma unroll + for (int j = 0; j < VecSize; j++) { + src_vec1[j] = act_functor(src_vec1[j]); + src_vec1[j] *= src_vec2[j]; + } + phi::Store(src_vec1, &output_this_thread[idx]); + } +} + +template +class FFNGluHelper { + public: + FFNGluHelper(const phi::GPUContext &dev_ctx, + const std::string &act_method, + int token_num, + int hid_dim, + int dim_ffn, + int dim_embed) + : dev_ctx_(dev_ctx), + act_method_(act_method), + token_num_(token_num), + hid_dim_(hid_dim), + dim_ffn_(dim_ffn), + dim_embed_(dim_embed) {} + + // dst = act(fc(src[0]) + bias) * src[1] + void Compute(const phi::DenseTensor *input, + const phi::DenseTensor *weight, + const phi::DenseTensor *bias, + phi::DenseTensor *bias_out, + phi::DenseTensor *output) { + // input's shape [token_num, dim_ffn], bias' shape [dim_ffn] + // output's shape [token_num, hid_dim], bias_out's shape [token_num, + // dim_ffn] + auto ffn_linear_compute = AttnMatMul( + dev_ctx_, false, false, token_num_, dim_ffn_, dim_embed_, true); + ffn_linear_compute.ComputeForward(weight, input, bias, bias_out, bias_out); + + using Functor = GeluFunctor; + + Functor functor; + constexpr int VecSize = 16; + constexpr int PackSize = VecSize / sizeof(T); + const int elem_cnt = token_num_ * hid_dim_; + const int blocksize = 128; + int grid_size = 1; + switch (hid_dim_ % PackSize) { + case 0: + GetNumBlocks(elem_cnt / PackSize, &grid_size); + ActFFNGlu + <<>>( + bias_out->data(), + output->data(), + functor, + token_num_, + hid_dim_, + elem_cnt); + break; + default: + GetNumBlocks(elem_cnt, &grid_size); + ActFFNGlu + <<>>( + bias_out->data(), + output->data(), + functor, + token_num_, + hid_dim_, + elem_cnt); + break; + } + } + + private: + const phi::GPUContext &dev_ctx_; + std::string act_method_; + int token_num_; + int hid_dim_; + int dim_ffn_; + int dim_embed_; +}; + } // namespace } // namespace operators diff --git a/paddle/fluid/operators/fused/layernorm_quant_dequant.h b/paddle/fluid/operators/fused/layernorm_quant_dequant.h new file mode 100644 index 0000000000000..6d9fccbcb6e48 --- /dev/null +++ b/paddle/fluid/operators/fused/layernorm_quant_dequant.h @@ -0,0 +1,1076 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace operators { + +constexpr int kWarpSize = 32; + +template +struct SumOp { + __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; } +}; + +template +struct MaxOp { + __device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); } +}; + +template class ReductionOp, typename T, int thread_group_width = kWarpSize> +__inline__ __device__ T WarpAllReduce(T val) { + for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { + val = ReductionOp()(val, __shfl_xor_sync(0xffffffff, val, mask, thread_group_width)); + } + return val; +} + +template class ReductionOp, typename T, int block_size> +__inline__ __device__ T BlockAllReduce(T val) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T result_broadcast; + T result = BlockReduce(temp_storage).Reduce(val, ReductionOp()); + if (threadIdx.x == 0) { result_broadcast = result; } + __syncthreads(); + return result_broadcast; +} + +template +__inline__ __device__ T Div(T a, T b); + +template<> +__inline__ __device__ float Div(float a, float b) { +#ifdef OF_LAYER_NORM_USE_FAST_MATH + return __fdividef(a, b); +#else + return a / b; +#endif +} + +template<> +__inline__ __device__ double Div(double a, double b) { + return a / b; +} + +template +__inline__ __device__ T Rsqrt(T x); + +template<> +__inline__ __device__ float Rsqrt(float x) { +#ifdef OF_LAYER_NORM_USE_FAST_MATH + return __frsqrt_rn(x); +#else + return rsqrt(x); +#endif +} + +template<> +__inline__ __device__ double Rsqrt(double x) { + return rsqrt(x); +} + +template +inline cudaError_t GetNumBlocks(Func func, int64_t block_size, size_t dynamic_smem_size, + int64_t max_blocks, int64_t waves, int* num_blocks) { + int dev; + { + cudaError_t err = cudaGetDevice(&dev); + if (err != cudaSuccess) { return err; } + } + int sm_count; + { + cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + if (err != cudaSuccess) { return err; } + } + int max_active_blocks; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, func, + block_size, dynamic_smem_size); + } + *num_blocks = + std::max(1, std::min(max_blocks, sm_count * max_active_blocks * waves)); + return cudaSuccess; +} + +template +struct DefaultComputeType { + using type = T; +}; + +template<> +struct DefaultComputeType { + using type = float; +}; + +#if CUDA_VERSION >= 11000 +template<> +struct DefaultComputeType { + using type = float; +}; +#endif // CUDA_VERSION >= 11000 + +template +class HasCanPackAs { + typedef char one; + struct two { + char x[2]; + }; + + template + static one test(decltype(&C::CanPackAs)); + template + static two test(...); + + public: + enum { value = sizeof(test(0)) == sizeof(char) }; +}; + +template +typename std::enable_if::value == true, bool>::type CanPackAs(T t, + size_t pack_size) { + return t.CanPackAs(pack_size); +} + +template +typename std::enable_if::value == false, bool>::type CanPackAs(T t, + size_t pack_size) { + return true; +} + +template +struct GetPackType { + using type = typename std::aligned_storage::type; +}; + +template +using PackType = typename GetPackType::type; + +template +union Pack { + static_assert(sizeof(PackType) == sizeof(T) * N, ""); + __device__ Pack() { + // do nothing + } + PackType storage; + T elem[N]; +}; + +template +struct DirectLoad { + using LoadType = DST; + DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {} + template + __device__ void load(DST* dst, int64_t row, int64_t col) const { + Pack pack; + const int64_t offset = (row * row_size + col) / N; + pack.storage = *(reinterpret_cast*>(src) + offset); +#pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]); } + } + const SRC* src; + int64_t row_size; +}; + +template +struct DirectStore { + DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {} + template + __device__ void store(const SRC* src, int64_t row, int64_t col) { + Pack pack; + const int64_t offset = (row * row_size + col) / N; +#pragma unroll + for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast(src[i]); } + *(reinterpret_cast*>(dst) + offset) = pack.storage; + } + DST* dst; + int64_t row_size; +}; + +template +inline __device__ void WelfordCombine(T val, T* mean, T* m2, T* count) { + // Use Welford Online algorithem to compute mean and variance + // For more details you can refer to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + *count += 1; + T delta1 = val - *mean; + *mean += Div(delta1, *count); + T delta2 = val - *mean; + *m2 += delta1 * delta2; +} + +template +inline __device__ void WelfordCombine(T b_mean, T b_m2, T b_count, T* mean, T* m2, T* count) { + if (b_count == 0) { return; } + T new_count = *count + b_count; + T nb_over_n = Div(b_count, new_count); + T delta = b_mean - *mean; + *mean += delta * nb_over_n; + *m2 += b_m2 + delta * delta * (*count) * nb_over_n; + *count = new_count; +} + +template +__inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T thread_count, T* mean, + T* m2, T* count) { + *mean = thread_mean; + *m2 = thread_m2; + *count = thread_count; + for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { + T b_mean = __shfl_down_sync(0xffffffff, *mean, mask, thread_group_width); + T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask, thread_group_width); + T b_count = __shfl_down_sync(0xffffffff, *count, mask, thread_group_width); + WelfordCombine(b_mean, b_m2, b_count, mean, m2, count); + } +} + +template +__inline__ __device__ void WelfordWarpAllReduce(T thread_mean, T thread_m2, T thread_count, T* mean, + T* m2, T* count) { + WelfordWarpReduce(thread_mean, thread_m2, thread_count, mean, m2, count); + *mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width); + *m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width); + *count = __shfl_sync(0xffffffff, *count, 0, thread_group_width); +} + +template +__inline__ __device__ T WarpReduceSum(T x) { + T result = 0.0f; + #pragma unroll + for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { + result += __shfl_xor_sync(0xffffffff, x, mask, thread_group_width); + } + return result; +} + +template +__inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T thread_count, + T* result_mean, T* result_m2, T* result_count) { + __shared__ T mean_shared[kWarpSize]; + __shared__ T m2_shared[kWarpSize]; + __shared__ T count_shared[kWarpSize]; + __shared__ T mean_result_broadcast; + __shared__ T m2_result_broadcast; + __shared__ T count_result_broadcast; + const int lid = threadIdx.x % kWarpSize; + const int wid = threadIdx.x / kWarpSize; + T warp_mean = 0; + T warp_m2 = 0; + T warp_count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count); + __syncthreads(); + if (lid == 0) { + mean_shared[wid] = warp_mean; + m2_shared[wid] = warp_m2; + count_shared[wid] = warp_count; + } + __syncthreads(); + if (wid == 0) { + if (threadIdx.x < blockDim.x / kWarpSize) { + warp_mean = mean_shared[lid]; + warp_m2 = m2_shared[lid]; + warp_count = count_shared[lid]; + } else { + warp_mean = static_cast(0); + warp_m2 = static_cast(0); + warp_count = static_cast(0); + } + __syncwarp(); + T block_mean = 0; + T block_m2 = 0; + T block_count = 0; + WelfordWarpReduce(warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count); + if (lid == 0) { + mean_result_broadcast = block_mean; + m2_result_broadcast = block_m2; + count_result_broadcast = block_count; + } + } + __syncthreads(); + *result_mean = mean_result_broadcast; + *result_m2 = m2_result_broadcast; + *result_count = count_result_broadcast; +} + +template +__global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + using LoadType = typename LOAD::LoadType; + static_assert(max_cols_per_thread % pack_size == 0, ""); + static_assert(min_cols_per_thread % pack_size == 0, ""); + static_assert(thread_group_width <= kWarpSize, ""); + static_assert(kWarpSize % thread_group_width == 0, ""); + constexpr int max_num_packs = max_cols_per_thread / pack_size; + constexpr int min_num_packs = min_cols_per_thread / pack_size; + assert(cols <= max_cols_per_thread * thread_group_width); + ComputeType buf[rows_per_access][max_cols_per_thread]; + const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; + const int64_t num_global_thread_group = gridDim.x * blockDim.y; + const int64_t lane_id = threadIdx.x; + const int64_t step = num_global_thread_group * rows_per_access; + for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) { + ComputeType thread_mean[rows_per_access]; + ComputeType thread_m2[rows_per_access]; + ComputeType thread_count[rows_per_access]; +#pragma unroll + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + thread_mean[row_id] = 0; + thread_m2[row_id] = 0; + thread_count[row_id] = 0; + ComputeType* row_buf = buf[row_id]; +#pragma unroll + for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) { + const int col = (pack_id * thread_group_width + lane_id) * pack_size; + const int pack_offset = pack_id * pack_size; + LoadType pack[pack_size]; + load.template load(pack, row + row_id, col); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + row_buf[pack_offset + i] = static_cast(pack[i]); + WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id, + thread_count + row_id); + } + } + for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) { + const int col = (pack_id * thread_group_width + lane_id) * pack_size; + const int pack_offset = pack_id * pack_size; + if (!padding || col < cols) { + LoadType pack[pack_size]; + load.template load(pack, row + row_id, col); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + row_buf[pack_offset + i] = static_cast(pack[i]); + WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id, + thread_count + row_id); + } + } else { +#pragma unroll + for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = 0; } + } + } + } + ComputeType warp_mean[rows_per_access]; + ComputeType warp_m2[rows_per_access]; + ComputeType warp_count[rows_per_access]; +#pragma unroll + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + int global_row_id = row + row_id; + ComputeType* row_buf = buf[row_id]; + WelfordWarpAllReduce( + thread_mean[row_id], thread_m2[row_id], thread_count[row_id], warp_mean + row_id, + warp_m2 + row_id, warp_count + row_id); + ComputeType row_mean = warp_mean[row_id]; + ComputeType row_variance = + max(Div(warp_m2[row_id], warp_count[row_id]), static_cast(0.0)); + ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon)); + if (lane_id == 0) { + mean[global_row_id] = row_mean; + inv_variance[global_row_id] = row_inv_var; + } +#pragma unroll + for (int i = 0; i < max_cols_per_thread; ++i) { + row_buf[i] = (row_buf[i] - row_mean) * row_inv_var; + } +#pragma unroll + for (int i = 0; i < min_num_packs; ++i) { + const int col = (i * thread_group_width + lane_id) * pack_size; + store.template store(row_buf + i * pack_size, global_row_id, col); + } +#pragma unroll + for (int i = min_num_packs; i < max_num_packs; ++i) { + const int col = (i * thread_group_width + lane_id) * pack_size; + if (!padding || col < cols) { + store.template store(row_buf + i * pack_size, global_row_id, col); + } + } + } + } +} + +template +inline cudaError_t LaunchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + constexpr int block_size = 128; + constexpr int waves = 32; + static_assert(block_size % thread_group_width == 0, ""); + constexpr int thread_groups_per_block = block_size / thread_group_width; + dim3 block_dim(thread_group_width, thread_groups_per_block); + const int64_t num_blocks = + (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; + int grid_dim_x; + { + cudaError_t err = GetNumBlocks( + LayerNormWarpImpl, + block_size, 0, num_blocks, waves, &grid_dim_x); + if (err != cudaSuccess) { return err; } + } + LayerNormWarpImpl + <<>>(load, store, rows, cols, epsilon, mean, inv_variance); + return cudaPeekAtLastError(); +} + +template +inline cudaError_t DispatchLayerNormWarpImplPadding(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + if (cols == max_cols_per_thread * thread_group_width) { + // when not padding, min_cols_per_thread must equals to max_cols_per_thread, pass + // max_cols_per_thread as min_cols_per_thread and max_cols_per_thread param. + return LaunchLayerNormWarpImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } else { + return LaunchLayerNormWarpImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } +} + +template +typename std::enable_if::type DispatchLayerNormWarpImplCols( + cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, ComputeType* inv_variance) { + if (cols <= 0) { return cudaErrorInvalidValue; } +#define DEFINE_ONE_ELIF(thread_group_width) \ + else if (cols <= (thread_group_width)*pack_size) { \ + if (rows % 2 == 0) { \ + return DispatchLayerNormWarpImplPadding( \ + stream, load, store, rows, cols, epsilon, mean, inv_variance); \ + } else { \ + return DispatchLayerNormWarpImplPadding( \ + stream, load, store, rows, cols, epsilon, mean, inv_variance); \ + } \ + } + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF +#define DEFINE_ONE_ELIF(max_col, min_col) \ + else if (cols <= (max_col)*kWarpSize) { \ + return DispatchLayerNormWarpImplPadding(stream, load, store, rows, cols, \ + epsilon, mean, inv_variance); \ + } + DEFINE_ONE_ELIF(2, 1) + DEFINE_ONE_ELIF(4, 2) + DEFINE_ONE_ELIF(8, 4) + DEFINE_ONE_ELIF(12, 8) + DEFINE_ONE_ELIF(16, 12) + DEFINE_ONE_ELIF(20, 16) + DEFINE_ONE_ELIF(24, 20) + DEFINE_ONE_ELIF(28, 24) + +#undef DEFINE_ONE_ELIF + else { + return cudaErrorInvalidValue; + } +} + +template +typename std::enable_if::type DispatchLayerNormWarpImplCols( + cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, ComputeType* inv_variance) { + if (cols <= 0) { return cudaErrorInvalidValue; } +#define DEFINE_ONE_ELIF(thread_group_width) \ + else if (cols <= (thread_group_width)*pack_size) { \ + if (rows % 2 == 0) { \ + return DispatchLayerNormWarpImplPadding( \ + stream, load, store, rows, cols, epsilon, mean, inv_variance); \ + } else { \ + return DispatchLayerNormWarpImplPadding( \ + stream, load, store, rows, cols, epsilon, mean, inv_variance); \ + } \ + } + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF +#define DEFINE_ONE_ELIF(max_col, min_col) \ + else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \ + return DispatchLayerNormWarpImplPadding(stream, load, store, rows, cols, \ + epsilon, mean, inv_variance); \ + } + DEFINE_ONE_ELIF(4, 2) + DEFINE_ONE_ELIF(8, 4) + DEFINE_ONE_ELIF(12, 8) + DEFINE_ONE_ELIF(16, 12) + DEFINE_ONE_ELIF(20, 16) + DEFINE_ONE_ELIF(24, 20) + DEFINE_ONE_ELIF(28, 24) + DEFINE_ONE_ELIF(32, 28) + DEFINE_ONE_ELIF(48, 44) + DEFINE_ONE_ELIF(52, 48) +#undef DEFINE_ONE_ELIF + else { + return cudaErrorInvalidValue; + } +} + + +template +struct DispatchLayerNormWarpImplPackSize { + cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + if (cols % 2 == 0 && CanPackAs(load, 2) && CanPackAs(store, 2)) { + return DispatchLayerNormWarpImplCols( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } + + else { + return DispatchLayerNormWarpImplCols( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } + } +}; + +template +inline cudaError_t DispatchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + return DispatchLayerNormWarpImplPackSize()( + stream, load, store, rows, cols, epsilon, mean, inv_variance); +} + +template +__global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, + ComputeType* mean, + ComputeType* inv_variance, + ComputeType col_divisor) { + using LoadType = typename LOAD::LoadType; + extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; + auto* buf = reinterpret_cast(shared_buf); + const int tid = threadIdx.x; + assert(cols % pack_size == 0); + const int num_packs = static_cast(cols) / pack_size; + for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { + ComputeType thread_sum = 0; + ComputeType thread_sum_square = 0; + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + LoadType pack[pack_size]; + load.template load(pack, row, pack_id * pack_size); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + buf[i * num_packs + pack_id] = pack[i]; + ComputeType pack_val = static_cast(pack[i]); + thread_sum += pack_val; + thread_sum_square += pack_val * pack_val; + } + } + + const ComputeType row_sum = BlockAllReduce(thread_sum); + const ComputeType row_sum_square = BlockAllReduce(thread_sum_square); + + // use multiply instead of divide. + ComputeType row_mean = row_sum * col_divisor; + ComputeType row_sum_square_mean = row_sum_square * col_divisor; + ComputeType row_variance = max(row_sum_square_mean - row_mean * row_mean, static_cast(0.0)); + ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon)); + if (threadIdx.x == 0) { + mean[row] = row_mean; + inv_variance[row] = row_inv_var; + } + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + ComputeType pack[pack_size]; +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] = (static_cast(buf[i * num_packs + pack_id]) - row_mean) * row_inv_var; + } + store.template store(pack, row, pack_id * pack_size); + } + } +} + + +template +inline cudaError_t LaunchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, + int smem, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance, + ComputeType col_divisor) { + constexpr int waves = 32; + int grid_dim_x; + { + cudaError_t err = + GetNumBlocks(LayerNormBlockSMemImpl, + block_size, smem, rows, waves, &grid_dim_x); + if (err != cudaSuccess) { return err; } + } + LayerNormBlockSMemImpl + <<>>(load, store, rows, cols, epsilon, mean, + inv_variance, col_divisor); + return cudaPeekAtLastError(); +} + +template +cudaError_t MaximizeDynamicSharedMemorySize(Func func, const int max_smem_size) { + cudaFuncAttributes attr{}; + cudaError_t err = cudaFuncGetAttributes(&attr, func); + if (err != cudaSuccess) { return err; } + constexpr int reserved_smem = 1024; // 1K + return cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_smem_size - attr.sharedSizeBytes - reserved_smem); +} + +template +inline cudaError_t TryDispatchLayerNormBlockSMemImplBlockSize( + cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, ComputeType* inv_variance, ComputeType col_divisor, bool* success) { + constexpr int block_size_conf_1 = 128; + constexpr int block_size_conf_2 = 256; + constexpr int block_size_conf_3 = 512; + constexpr int block_size_conf_4 = 1024; + + int dev = 0; + { + cudaError_t err = cudaGetDevice(&dev); + if (err != cudaSuccess) { return err; } + } + + int sm_count = 0; + { + cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + if (err != cudaSuccess) { return err; } + } + + static const bool max_smem_configed = [=]() { + int max_smem_size = 0; + cudaError_t err = + cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + if (err != cudaSuccess) { return false; } + + err = MaximizeDynamicSharedMemorySize( + LayerNormBlockSMemImpl, + max_smem_size); + if (err != cudaSuccess) { return false; } + err = MaximizeDynamicSharedMemorySize( + LayerNormBlockSMemImpl, + max_smem_size); + if (err != cudaSuccess) { return false; } + err = MaximizeDynamicSharedMemorySize( + LayerNormBlockSMemImpl, + max_smem_size); + if (err != cudaSuccess) { return false; } + err = MaximizeDynamicSharedMemorySize( + LayerNormBlockSMemImpl, + max_smem_size); + if (err != cudaSuccess) { return false; } + + return true; + }(); + + const size_t smem = cols * sizeof(typename LOAD::LoadType); + + int max_active_blocks_conf_1; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_1, + LayerNormBlockSMemImpl, + block_size_conf_1, smem); + if (err != cudaSuccess) { return err; } + } + if (max_active_blocks_conf_1 <= 0) { + *success = false; + return cudaSuccess; + } + + int max_active_blocks_conf_4; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_4, + LayerNormBlockSMemImpl, + block_size_conf_4, smem); + if (err != cudaSuccess) { return err; } + } + + if (max_active_blocks_conf_4 == max_active_blocks_conf_1 + || (max_active_blocks_conf_4 > 0 && rows <= sm_count)) { + *success = true; + return LaunchLayerNormBlockSMemImpl( + stream, load, store, smem, rows, cols, epsilon, mean, inv_variance, col_divisor); + } + + int max_active_blocks_conf_3; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_3, + LayerNormBlockSMemImpl, + block_size_conf_3, smem); + if (err != cudaSuccess) { return err; } + } + if (max_active_blocks_conf_3 == max_active_blocks_conf_1 + || (max_active_blocks_conf_3 > 0 && rows <= sm_count)) { + *success = true; + return LaunchLayerNormBlockSMemImpl( + stream, load, store, smem, rows, cols, epsilon, mean, inv_variance, col_divisor); + } + + int max_active_blocks_conf_2; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_2, + LayerNormBlockSMemImpl, + block_size_conf_2, smem); + if (err != cudaSuccess) { return err; } + } + if (max_active_blocks_conf_2 == max_active_blocks_conf_1 + || (max_active_blocks_conf_2 > 0 && rows <= sm_count)) { + *success = true; + return LaunchLayerNormBlockSMemImpl( + stream, load, store, smem, rows, cols, epsilon, mean, inv_variance, col_divisor); + } + + *success = true; + return LaunchLayerNormBlockSMemImpl( + stream, load, store, smem, rows, cols, epsilon, mean, inv_variance, col_divisor); +} + +template +struct TryDispatchLayerNormBlockSMemImplPackSize { + cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, ComputeType* mean, + ComputeType* inv_variance, + ComputeType col_divisor, bool* success) { + if (cols % 4 == 0 && CanPackAs(load, 4) && CanPackAs(store, 4)) { + return TryDispatchLayerNormBlockSMemImplBlockSize( + stream, load, store, rows, cols, epsilon, mean, inv_variance, col_divisor, success); + } else if (cols % 2 == 0 && CanPackAs(load, 2) && CanPackAs(store, 2)) { + return TryDispatchLayerNormBlockSMemImplBlockSize( + stream, load, store, rows, cols, epsilon, mean, inv_variance, col_divisor, success); + } else { + return TryDispatchLayerNormBlockSMemImplBlockSize( + stream, load, store, rows, cols, epsilon, mean, inv_variance, col_divisor, success); + } + } +}; + +template +inline cudaError_t TryDispatchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance, + ComputeType col_divisor, bool* success) { + return TryDispatchLayerNormBlockSMemImplPackSize()( + stream, load, store, rows, cols, epsilon, mean, inv_variance, col_divisor, success); +} + +template +__global__ void __launch_bounds__(1024) + LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, ComputeType* inv_variance) { + using LoadType = typename LOAD::LoadType; + const int tid = threadIdx.x; + assert(cols % pack_size == 0); + const int num_packs = static_cast(cols) / pack_size; + for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { + ComputeType thread_mean = 0; + ComputeType thread_m2 = 0; + ComputeType thread_count = 0; + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + LoadType pack[pack_size]; + load.template load(pack, row, pack_id * pack_size); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + WelfordCombine(static_cast(pack[i]), &thread_mean, &thread_m2, &thread_count); + } + } + ComputeType row_mean = 0; + ComputeType row_m2 = 0; + ComputeType row_count = 0; + WelfordBlockAllReduce(thread_mean, thread_m2, thread_count, &row_mean, &row_m2, + &row_count); + ComputeType row_variance = max(Div(row_m2, row_count), static_cast(0.0)); + ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon)); + if (threadIdx.x == 0) { + mean[row] = row_mean; + inv_variance[row] = row_inv_var; + } + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + LoadType pack[pack_size]; + ComputeType dst_pack[pack_size]; + const int pack_offset = pack_id * pack_size; + load.template load(pack, row, pack_offset); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + dst_pack[i] = (static_cast(pack[i]) - row_mean) * row_inv_var; + } + store.template store(dst_pack, row, pack_offset); + } + } +} + +template +inline cudaError_t LaunchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + constexpr int block_size = 1024; + constexpr int waves = 32; + int grid_dim_x; + { + cudaError_t err = + GetNumBlocks(LayerNormBlockUncachedImpl, + block_size, 0, rows, waves, &grid_dim_x); + if (err != cudaSuccess) { return err; } + } + LayerNormBlockUncachedImpl + <<>>(load, store, rows, cols, epsilon, mean, inv_variance); + return cudaPeekAtLastError(); +} + +template +struct DispatchLayerNormBlockUncachedImplPackSize { + cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + if (cols % 4 == 0 && CanPackAs(load, 4) && CanPackAs(store, 4)) { + return LaunchLayerNormBlockUncachedImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } else if (cols % 2 == 0 && CanPackAs(load, 2) && CanPackAs(store, 2)) { + return LaunchLayerNormBlockUncachedImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } else { + return LaunchLayerNormBlockUncachedImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } + } +}; + +template +inline cudaError_t DispatchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + return DispatchLayerNormBlockUncachedImplPackSize()( + stream, load, store, rows, cols, epsilon, mean, inv_variance); +} + +template +inline typename std::enable_if::value, cudaError_t>::type +DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + const ComputeType col_divisor = 1.0f / cols; + if (cols <= 1024) { + return DispatchLayerNormWarpImpl(stream, load, store, rows, cols, + epsilon, mean, inv_variance); + } else { + bool dispatch_smem_impl_success; + { + cudaError_t err = TryDispatchLayerNormBlockSMemImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance, col_divisor, + &dispatch_smem_impl_success); + if (err != cudaSuccess) { return err; } + } + if (!dispatch_smem_impl_success) { + return DispatchLayerNormBlockUncachedImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } + return cudaSuccess; + } +} + +template +inline typename std::enable_if::value, cudaError_t>::type +DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + return DispatchLayerNormBlockUncachedImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); +} + +template +struct DequantSkipLoad { + using LoadType = DST; + DequantSkipLoad(const InputType* src, const SRC* bias, const SRC* skip, const float* dequant_scale, float alpha, int64_t row_size) + : src(src), bias(bias), skip(skip), dequant_scale(dequant_scale), alpha(alpha), row_size(row_size) {} + template + __device__ void load(DST* dst, int64_t row, int64_t col) const { + Pack src_pack; + Pack bias_pack; + Pack skip_pack; + Pack dequant_scale_pack; + const int64_t offset = (row * row_size + col) / N; + const int64_t bias_offset = col / N; + src_pack.storage = *(reinterpret_cast*>(src) + offset); + bias_pack.storage = *(reinterpret_cast*>(bias) + bias_offset); + skip_pack.storage = *(reinterpret_cast*>(skip) + offset); + dequant_scale_pack.storage = *(reinterpret_cast*>(dequant_scale) + bias_offset); // equal to col. +#pragma unroll + for (int i = 0; i < N; ++i) { + // First we need to cast src and dequant. + dst[i] = static_cast(static_cast(static_cast(src_pack.elem[i]) * dequant_scale_pack.elem[i]) + + bias_pack.elem[i] + + skip_pack.elem[i]); + } + } + const InputType* src; + const SRC* bias; + const SRC* skip; + const float* dequant_scale; + double alpha; + int64_t row_size; +}; + +template +__device__ __inline__ +T ClipFunc(const T v, const T min, const T max){ + if(v > max) return max; + if(v < min) return min; + return v; +} + +template +__forceinline__ __device__ OutType QuantHelperFunc(const InType input, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + float quant_value = max_bound * scale * input; + + if (round_type == 0) { + // quant_value = static_cast(roundWithTiesToEven(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } +// quant_value = quant_value > max_bound ? max_bound : quant_value; +// quant_value = quant_value < min_bound ? min_bound : quant_value; +// return static_cast(quant_value); + return static_cast(ClipFunc(quant_value, min_bound, max_bound)); +} + +template +struct AffineQuantStore { + AffineQuantStore(OutType* y, + const int64_t row_size, + const float* gamma, const float* beta, + const float quant_out_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) : y(y), row_size(row_size), gamma(gamma), beta(beta), quant_round_type(quant_round_type), + quant_out_scale(quant_out_scale), quant_max_bound(quant_max_bound), quant_min_bound(quant_min_bound) {} + + template + __device__ void store(const SRC* src, int64_t row, int64_t col) { + Pack y_pack; + Pack gamma_pack; + Pack beta_pack; + // Pack out_pack; + const int64_t offset = (row * row_size + col) / N; + const int64_t gamma_offset = col / N; + gamma_pack.storage = *(reinterpret_cast*>(gamma) + gamma_offset); + beta_pack.storage = *(reinterpret_cast*>(beta) + gamma_offset); +#pragma unroll + for (int i = 0; i < N; ++i) { + float normalized_i = static_cast(src[i]); + float normalized_val = normalized_i * gamma_pack.elem[i] + beta_pack.elem[i]; + if (do_scale) { + y_pack.elem[i] = QuantHelperFunc(normalized_val, quant_out_scale, quant_round_type, quant_max_bound, quant_min_bound); + } else { + y_pack.elem[i] = static_cast(normalized_val); + } + } + *(reinterpret_cast*>(y) + offset) = y_pack.storage; + } + + OutType* y; + int64_t row_size; + const float* gamma; + const float* beta; + const int quant_round_type; + const float quant_out_scale; + const float quant_max_bound; + const float quant_min_bound; +}; + +template +struct DequantSkipLoadAndStoreResidual { + using LoadType = DST; + // need to aseert SRC equals to DST. + DequantSkipLoadAndStoreResidual(const InputType* src, + const SRC* bias, + const SRC* skip, + const float* dequant_scale, + SRC* residual_bias_out, + float alpha, int64_t row_size) + : src(src), bias(bias), skip(skip), dequant_scale(dequant_scale), residual_bias_out(residual_bias_out), alpha(alpha), row_size(row_size) {} + template + __device__ void load(DST* dst, int64_t row, int64_t col) const { + Pack src_pack; + Pack bias_pack; + Pack skip_pack; + Pack dequant_scale_pack; + Pack residual_out_pack; + + const int64_t offset = (row * row_size + col) / N; + const int64_t bias_offset = col / N; + src_pack.storage = *(reinterpret_cast*>(src) + offset); + bias_pack.storage = *(reinterpret_cast*>(bias) + bias_offset); + skip_pack.storage = *(reinterpret_cast*>(skip) + offset); + dequant_scale_pack.storage = *(reinterpret_cast*>(dequant_scale) + bias_offset); // equal to col. +#pragma unroll + for (int i = 0; i < N; ++i) { + // First we need to cast src and dequant. + if (do_dequant) { + residual_out_pack.elem[i] = static_cast(static_cast(static_cast(src_pack.elem[i]) * dequant_scale_pack.elem[i]) + + bias_pack.elem[i] + + skip_pack.elem[i]); + } else { + residual_out_pack.elem[i] = static_cast(static_cast(src_pack.elem[i]) + bias_pack.elem[i] + + skip_pack.elem[i]); + } + } +#pragma unroll + for (int i = 0; i < N; ++i) { + dst[i] = residual_out_pack.elem[i]; + } + *(reinterpret_cast*>(residual_bias_out) + offset) = residual_out_pack.storage; + } + const InputType* src; + const SRC* bias; + const SRC* skip; + const float* dequant_scale; + SRC* residual_bias_out; + double alpha; + int64_t row_size; +}; + +} // namespace operators +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/fused/quant_dequant_kernel.h b/paddle/fluid/operators/fused/quant_dequant_kernel.h index 21b7b0f345466..bd490555a8b86 100644 --- a/paddle/fluid/operators/fused/quant_dequant_kernel.h +++ b/paddle/fluid/operators/fused/quant_dequant_kernel.h @@ -18,17 +18,85 @@ limitations under the License. */ #include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" namespace paddle { namespace operators { +using phi::backends::gpu::GpuLaunchConfig; + +constexpr int DequantKernelVecSize = 4; + + +template +struct QuantFunc{ + HOSTDEVICE int8_t operator()(const T x, const float scale, const float max_bound, + const float min_bound) { + float tmp = static_cast(x) * max_bound * scale; + tmp = round(tmp); + if (tmp > max_bound) + tmp = max_bound; + else if (tmp < min_bound) + tmp = min_bound; + return static_cast(tmp); + } +}; + +template +__global__ void QuantActKernel(const T* x, const int32_t rows, const int32_t cols, float scale, int8_t* quant_x, + const float max_bound, + const float min_bound) { + + using InVec = phi::AlignedVector; + using OutVec = phi::AlignedVector; + + const int stride = blockDim.x * gridDim.x * VecSize; + const int num_items = rows * cols; + + InVec in_vec; + OutVec out_vec; + for(int32_t linear_index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; linear_index < num_items; linear_index += stride){ + phi::Load(x + linear_index, &in_vec); + #pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = QuantFunc()(in_vec[i], scale, max_bound, min_bound); + } + phi::Store(out_vec, quant_x + linear_index); + } +} + + +template +void LaunchQuantActKernel(const T* x, const int32_t rows, const int32_t cols, int8_t* quant_x, float scale, + const float max_bound, const float min_bound, gpuStream_t stream) { + constexpr int NumThreads=256; + constexpr int VecSize= 16 / sizeof(T); + + constexpr int kNumWaves = 8; + int dev; + cudaGetDevice(&dev); + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + int tpm; + cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev); + const int elem_cnt = rows*cols; + const int launch_elem_cnt = elem_cnt / VecSize; + const int grid_size = std::max(1, std::min((launch_elem_cnt + NumThreads - 1) / NumThreads, + sm_count * tpm / NumThreads * kNumWaves)); + + QuantActKernel<<>>(x, rows, cols, scale, quant_x, max_bound, min_bound); +} + + template __forceinline__ __device__ int8_t quant_helper(const T input, const float scale, const int round_type, const float max_bound, const float min_bound) { - float quant_value = max_bound * inverse(scale) * static_cast(input); + float quant_value = max_bound * scale * static_cast(input); + if (round_type == 0) { quant_value = static_cast(roundWithTiesToEven(quant_value)); } else { @@ -77,7 +145,7 @@ void quantize_kernel_launcher(const T* input, const float min_bound, gpuStream_t stream) { // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 - dim3 grid((n + 31) / 32, (m + 31) / 32); + dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32); dim3 block(32, 32); quantize_kernel<<>>(input, @@ -90,47 +158,49 @@ void quantize_kernel_launcher(const T* input, min_bound); } -// dequantize using weight scales and input scales -template +template __global__ void dequantize_kernel(T* output, const int32_t* input, - const int m, // hidden - const int n, // batch size + const int m, // batch size + const int n, // hidden const float quant_in_scale, - const float* dequant_out_scale_data, - const int quant_out_scale_offset) { - int m_id = blockIdx.x * blockDim.x + threadIdx.x; // hidden - int n_id = blockIdx.y * blockDim.y + threadIdx.y; // batch size - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - float out_scale = dequant_out_scale_data[quant_out_scale_offset + m_id]; - output[n_id * m + m_id] = - static_cast(static_cast(input[n_id * m + m_id]) * - quant_in_scale / out_scale); + const float* dequant_out_scale_data) { + int numel = m * n; + int stride = blockDim.x * gridDim.x * VecSize; + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + int col_id = idx % n; + + phi::AlignedVector in_vec; + phi::AlignedVector out_scale_vec; + phi::AlignedVector out_vec; + + for (; idx < numel; idx += stride) { + phi::Load(input + idx, &in_vec); + phi::Load(dequant_out_scale_data + col_id, &out_scale_vec); + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = + static_cast(static_cast(in_vec[i]) * out_scale_vec[i]); + } + + phi::Store(out_vec, output + idx); } } template void dequantize_kernel_launcher(const int32_t* input, T* output, - const int batch_size, // m - const int hidden_units, // n + const int m, // m + const int n, // n gpuStream_t stream, + GpuLaunchConfig* gpu_config, const float quant_in_scale, - const float* dequant_out_scale_data, - const int quant_out_scale_offset) { - dim3 grid((hidden_units + 31) / 32, (batch_size + 31) / 32); - dim3 block(32, 32); - - dequantize_kernel<<>>(output, - input, - hidden_units, - batch_size, - quant_in_scale, - dequant_out_scale_data, - quant_out_scale_offset); + const float* dequant_out_scale_data) { + dequantize_kernel + <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( + output, input, m, n, quant_in_scale, dequant_out_scale_data); } } // namespace operators -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/platform/dynload/cublasLt.h b/paddle/fluid/platform/dynload/cublasLt.h index c3425ac604858..43a62ac2f9742 100644 --- a/paddle/fluid/platform/dynload/cublasLt.h +++ b/paddle/fluid/platform/dynload/cublasLt.h @@ -39,7 +39,7 @@ namespace dynload { extern DynLoad__##__name __name // APIs available after CUDA 10.1 -// #if CUDA_VERSION >= 10100 +#if CUDA_VERSION >= 11010 #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasLtCreate); \ __macro(cublasLtDestroy); \ @@ -61,7 +61,33 @@ namespace dynload { __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); \ __macro(cublasLtMatmulAlgoInit); \ - __macro(cublasLtMatmulAlgoConfigSetAttribute); + __macro(cublasLtMatmulAlgoConfigSetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ + __macro(cublasLtMatmulAlgoCheck); \ + __macro(cublasLtGetCudartVersion); +#else +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); +#endif CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif @@ -69,4 +95,4 @@ CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) #undef PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP } // namespace dynload } // namespace platform -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index af080bd0b3431..661e0edcd441d 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -61,7 +61,9 @@ std::map> op_ins_map = { "QKVW", "QKVBias", "CacheKV", + "BeamCacheOffset", "TimeStep", + "SeqLengths", "SrcMask", "OutLinearW", "OutLinearBias", @@ -71,12 +73,39 @@ std::map> op_ins_map = { "FFN1Bias", "FFN2Weight", "FFN2Bias"}}, + {"fused_multi_transformer_moe", + {"X", + "LnScale", + "LnBias", + "QKVW", + "QKVBias", + "CacheKV", + "BeamCacheOffset", + "TimeStep", + "SeqLengths", + "SrcMask", + "OutLinearW", + "OutLinearBias", + "GateWeight", + "GateBias", + "FFNLnScale", + "FFNLnBias", + "ExpertWeight1", + "ExpertBias1", + "ExpertWeight2", + "ExpertBias2"}}, {"fused_multi_transformer_int8", {"X", "LnScale", "LnBias", "QKVW", - "QKVBias", "CacheKV", "TimeStep", "SrcMask", + "QKVBias", "CacheKV", "BeamCacheOffset", "TimeStep", "SeqLengths", "SrcMask", "OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias", "FFN1Weight", "FFN1Bias", "FFN2Weight", "FFN2Bias", "QKVOutScale", "OutLinearOutScale", "FFN1OutScale", "FFN2OutScale"}}, + {"fused_multi_transformer_moe_int8", + {"X", "LnScale", "LnBias", "QKVW", + "QKVBias", "CacheKV", "BeamCacheOffset", "TimeStep", "SeqLengths", "SrcMask", + "OutLinearW", "OutLinearBias", "GateWeight", "GateBias", "FFNLnScale", "FFNLnBias", + "ExpertWeight1", "ExpertBias1", "ExpertWeight2", "ExpertBias2", + "QKVOutScale", "OutLinearOutScale", "ExpertWeight1OutScale", "ExpertWeight2OutScale"}}, {"fused_bias_dropout_residual_layer_norm", {"X", "Residual", "Bias", "LnScale", "LnBias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, @@ -335,7 +364,9 @@ std::map> op_outs_map = { "Beta2PowOut", "MasterParamOut"}}, {"fused_multi_transformer", {"CacheKVOut", "Out"}}, + {"fused_multi_transformer_moe", {"CacheKVOut", "Out"}}, {"fused_multi_transformer_int8", {"CacheKVOut", "Out"}}, + {"fused_multi_transformer_moe_int8", {"CacheKVOut", "Out"}}, {"resnet_basic_block", {"Y", "Conv1", "SavedMean1", "SavedInvstd1", "Mean1Out", "Var1Out", "Conv2", "SavedMean2", "SavedInvstd2", "Mean2Out", @@ -440,7 +471,9 @@ std::map> op_passing_outs_map = { {"split", {"Out"}}, {"concat", {"Out"}}, {"fused_multi_transformer", {"CacheKVOut"}}, + {"fused_multi_transformer_moe", {"CacheKVOut"}}, {"fused_multi_transformer_int8", {"CacheKVOut"}}, + {"fused_multi_transformer_moe_int8", {"CacheKVOut"}}, {"group_norm", {"Mean", "Variance"}}, {"resnet_basic_block", {"Mean1Out", "Var1Out", "Mean2Out", "Var2Out", "Mean3Out", "Var3Out"}}, diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index d3cd00b3a541c..38bdacc3bcf6f 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -228,7 +228,10 @@ data_type : numbers - op : fused_moe_kernel - args : (Tensor x, Tensor gate_weight, Tensor gate_bias, Tensor ln_scale, Tensor ln_bias, Tensor[] experts_weight1, Tensor[] experts_bias1, Tensor[] experts_weight2, Tensor[] experts_bias2, bool pre_layer_norm, float ln_epsilon, int topk, int mp_size, int mp_rank, int num_expert, int world_size, int moe_ring_id, bool approximate) + args : (Tensor x, Tensor residual, Tensor gate_weight, Tensor gate_bias, Tensor ln_scale, Tensor ln_bias, + Tensor[] experts_weight1, Tensor[] experts_bias1, Tensor[] experts_weight2, Tensor[] experts_bias2, + bool pre_layer_norm, float ln_epsilon, int topk, int mp_size, int mp_rank, int num_expert, int world_size, + int moe_ring_id, bool approximate, int bsz, int seq_len, int d_model, int dim_feedforward) output : Tensor(out) infer_meta : func : FusedMoeInferMeta diff --git a/paddle/phi/backends/dynload/cublasLt.h b/paddle/phi/backends/dynload/cublasLt.h index 90492ff4ba69d..d078feef9c28e 100644 --- a/paddle/phi/backends/dynload/cublasLt.h +++ b/paddle/phi/backends/dynload/cublasLt.h @@ -54,6 +54,7 @@ extern void *cublasLt_dso_handle; // APIs available after CUDA 10.1 // #if CUDA_VERSION >= 10100 +#if CUDA_VERSION >= 11010 #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasLtCreate); \ __macro(cublasLtDestroy); \ @@ -75,11 +76,37 @@ extern void *cublasLt_dso_handle; __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); \ __macro(cublasLtMatmulAlgoInit); \ - __macro(cublasLtMatmulAlgoConfigSetAttribute); + __macro(cublasLtMatmulAlgoConfigSetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ + __macro(cublasLtMatmulAlgoCheck); \ + __macro(cublasLtGetCudartVersion); +#else +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); +#endif CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif #undef DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP } // namespace dynload -} // namespace phi +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc index 4a16480101a70..08975eca01948 100644 --- a/paddle/phi/backends/gpu/gpu_resources.cc +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -174,13 +174,13 @@ void DestroyBlasHandle(blasHandle_t handle) { } void InitBlasLtHandle(blasLtHandle_t* blaslt_handle) { -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10020 phi::dynload::cublasLtCreate(blaslt_handle); #endif } void DestroyBlasLtHandle(blasLtHandle_t handle) { -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10020 if (handle != nullptr) { phi::dynload::cublasLtDestroy(handle); handle = nullptr; diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 0a45e1fb0530b..6340dbfd57401 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2949,6 +2949,7 @@ void GraphSendUVInferMeta(const MetaTensor& x, } void FusedMoeInferMeta(const MetaTensor& x, + const MetaTensor& residual, const MetaTensor& gate_weight, const MetaTensor& gate_bias, const MetaTensor& ln_scale, @@ -2966,6 +2967,10 @@ void FusedMoeInferMeta(const MetaTensor& x, int world_size, int moe_ring_id, bool approximate, + int bsz, + int seq_len, + int d_model, + int dim_feedforward, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(x.dtype()); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 948c5ca75ac4c..2d1a45f31c0ce 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -545,6 +545,7 @@ void GraphSendUVInferMeta(const MetaTensor& x, MetaTensor* out); void FusedMoeInferMeta(const MetaTensor& x, + const MetaTensor& residual, const MetaTensor& gate_weight, const MetaTensor& gate_bias, const MetaTensor& ln_scale, @@ -562,5 +563,9 @@ void FusedMoeInferMeta(const MetaTensor& x, int world_size, int moe_ring_id, bool approximate, + int bsz, + int seq_len, + int d_model, + int dim_feedforward, MetaTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/fused_moe_kernel.h b/paddle/phi/kernels/fused_moe_kernel.h index dbafdf3918025..d6a1f112c762c 100644 --- a/paddle/phi/kernels/fused_moe_kernel.h +++ b/paddle/phi/kernels/fused_moe_kernel.h @@ -49,9 +49,109 @@ #endif namespace phi { +using Tensor = DenseTensor; namespace framework = paddle::framework; namespace platform = paddle::platform; +template +static void AllToAll(Tensor& tensor, // NOLINT + Tensor& out, + const int ring_id, + const phi::GPUContext& ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + auto pg_nccl = static_cast(pg); + + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(out); + auto task = pg_nccl->AllToAll(in_tensor, out_tensor); + task->Wait(); + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t send_numel = tensor.numel(); // send_numel + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + int nranks = comm->nranks(); + auto stream = ctx.stream(); + + framework::DDim x_dims = tensor.dims(); + framework::DDim out_dims(x_dims); + PADDLE_ENFORCE_EQ( + x_dims[0] % nranks, + 0, + platform::errors::InvalidArgument( + "The first dimension size (%d) of the input tensor must be " + "divisible by the number of ranks (%d).", + x_dims[0], + nranks)); + auto send_buf = tensor.data(); + auto recv_buf = out.mutable_data(out_dims, place); + size_t offset = 0; + send_numel /= nranks; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < nranks; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + offset += send_numel; + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + +template +static void AllGather(Tensor& tensor, // NOLINT + Tensor& out, + const int ring_id, + const phi::GPUContext& ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + auto pg_nccl = static_cast(pg); + + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(out); + auto task = pg_nccl->AllGather(in_tensor, out_tensor, true, true); + task->Wait(); + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t numel = tensor.numel(); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + auto stream = ctx.stream(); + auto out_dims = tensor.dims(); + int nranks = comm->nranks(); + out_dims[0] *= nranks; + out.mutable_data(out_dims, place); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( + tensor.data(), out.data(), numel, dtype, comm->comm(), stream)); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + template void GlobalScatterFunctor(const phi::GPUContext& ctx, const framework::Tensor* x, @@ -142,6 +242,11 @@ void GlobalScatterFunctor(const phi::GPUContext& ctx, } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); } +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif #else PADDLE_THROW( @@ -337,6 +442,13 @@ void GlobalGatherFunctor(const phi::GPUContext& ctx, } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); } + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + #else PADDLE_THROW( platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); @@ -480,6 +592,7 @@ void MatMulAndAdd(const phi::GPUContext& dev_ctx, template void FusedMoeKernel(const DeviceContext& context, const DenseTensor& x, + const DenseTensor& residual, const DenseTensor& gate_weight, const DenseTensor& gate_bias, const DenseTensor& ln_scale, @@ -497,6 +610,10 @@ void FusedMoeKernel(const DeviceContext& context, int world_size, int moe_ring_id, bool approximate, + int bsz, + int seq_len, + int d_model, + int dim_feedforward, DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu index 5652adfd1c50b..0deaa33c6adab 100644 --- a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu +++ b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu @@ -826,6 +826,7 @@ void invokeTopkSoftMax(const Context &dev_ctx, CASE_K(14); CASE_K(15); CASE_K(16); + CASE_K(50); default: PADDLE_THROW(paddle::platform::errors::Unimplemented( "beam_size = %d is unsupport!", beam_size)); diff --git a/paddle/phi/kernels/gpu/fused_moe_kernel.cu b/paddle/phi/kernels/gpu/fused_moe_kernel.cu index 0d72d7e3b058f..657f53b9e29a8 100644 --- a/paddle/phi/kernels/gpu/fused_moe_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_moe_kernel.cu @@ -19,108 +19,10 @@ using Tensor = DenseTensor; namespace framework = paddle::framework; namespace platform = paddle::platform; -template -static void AllToAll(Tensor& tensor, // NOLINT - Tensor& out, - const int ring_id, - const phi::GPUContext& ctx) { - if (ring_id == -1) return; -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); - - if (map->has(ring_id)) { - paddle::distributed::ProcessGroup* pg = map->get(ring_id); - auto pg_nccl = static_cast(pg); - - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(tensor); - out_tensor.push_back(out); - auto task = pg_nccl->AllToAll(in_tensor, out_tensor, true, true); - task->Wait(); - } else { - auto dtype = platform::ToNCCLDataType( - framework::TransToProtoVarType(tensor.dtype())); - int64_t send_numel = tensor.numel(); // send_numel - auto place = ctx.GetPlace(); - auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - int nranks = comm->nranks(); - auto stream = ctx.stream(); - - framework::DDim x_dims = tensor.dims(); - framework::DDim out_dims(x_dims); - PADDLE_ENFORCE_EQ( - x_dims[0] % nranks, - 0, - platform::errors::InvalidArgument( - "The first dimension size (%d) of the input tensor must be " - "divisible by the number of ranks (%d).", - x_dims[0], - nranks)); - auto send_buf = tensor.data(); - auto recv_buf = out.mutable_data(out_dims, place); - size_t offset = 0; - send_numel /= nranks; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - for (auto i = 0; i < nranks; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( - send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( - recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); - offset += send_numel; - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - } -#else - PADDLE_THROW(platform::errors::Unimplemented( - "PaddlePaddle should compile with NCCL or RCCL when used tensor model " - "parallel op.")); -#endif -} - -template -static void AllGather(Tensor& tensor, // NOLINT - Tensor& out, - const int ring_id, - const phi::GPUContext& ctx) { - if (ring_id == -1) return; -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); - - if (map->has(ring_id)) { - paddle::distributed::ProcessGroup* pg = map->get(ring_id); - auto pg_nccl = static_cast(pg); - - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(tensor); - out_tensor.push_back(out); - auto task = pg_nccl->AllGather(in_tensor, out_tensor, true, true); - task->Wait(); - } else { - auto dtype = platform::ToNCCLDataType( - framework::TransToProtoVarType(tensor.dtype())); - int64_t numel = tensor.numel(); - auto place = ctx.GetPlace(); - auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - auto stream = ctx.stream(); - auto out_dims = tensor.dims(); - int nranks = comm->nranks(); - out_dims[0] *= nranks; - out.mutable_data(out_dims, place); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( - tensor.data(), out.data(), numel, dtype, comm->comm(), stream)); - } -#else - PADDLE_THROW(platform::errors::Unimplemented( - "PaddlePaddle should compile with NCCL or RCCL when used tensor model " - "parallel op.")); -#endif -} - template void FusedMoeKernel(const DeviceContext& dev_ctx, const DenseTensor& x, + const DenseTensor& residual, const DenseTensor& gate_weight, const DenseTensor& gate_bias, const DenseTensor& ln_scale, @@ -138,18 +40,20 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, int world_size, int moe_ring_id, bool approximate, + int bsz, + int seq_len, + int d_model, + int dim_feedforward, DenseTensor* out) { using U = paddle::operators::LayerNormParamType; - // output - dev_ctx.template Alloc(out); // dim auto x_dim = x.dims(); - int bsz = x_dim[0]; - int seq_len = x_dim[1]; + // output + out->Resize(x_dim); + dev_ctx.template Alloc(out); + // auto out_dim = out->dims(); int bsz_seq = bsz * seq_len; - int d_model = x_dim[2]; int tot_expert = world_size * num_expert; - int dim_feedforward = experts_weight1[0]->dims()[1]; // pre_layer_norm const U* ln_scale_ptr = ln_scale.data(); @@ -228,8 +132,11 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, all_gather_out.Resize({{bsz_seq, d_model}}); dev_ctx.template Alloc(&all_gather_out); paddle::operators::DropoutParam dropout_param(false, 0, true, true, 0.0, nullptr, 0); + // for naccl comm + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); // step1 layer norm + // VLOG(0) << "moe, layer norm"; if (pre_layer_norm) { pre_layernorm_helper.LayerNorm(dev_ctx, x.data(), @@ -241,6 +148,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, } else { ln_out = x; } + // VLOG(0) << "moe, resize and slice ln_out"; // step2 resize and slice ln_out ln_out.Resize({{bsz_seq, d_model}}); if (mp_size > 1) { @@ -248,6 +156,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, } else { sliced_inp = ln_out; } + // VLOG(0) << "moe, gate & topk"; // step3 gate & topk MatMulAndAdd(dev_ctx, &gate_weight, @@ -268,8 +177,10 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, &topk_idx); // step4 prepare forward // step4.1 number count + // VLOG(0) << "moe, number count"; NumberCountKernel(dev_ctx, topk_idx, tot_expert, &local_expert_count); // step4.2 all_to_all + // VLOG(0) << "moe, all_to_all"; if (world_size > 1) { AllToAll(local_expert_count, global_expert_count, moe_ring_id, dev_ctx); } else { @@ -278,6 +189,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, // global expert count resize global_expert_count.Resize({{world_size, num_expert}}); // fwd expert count + // VLOG(0) << "moe, fwd expert count"; SumKernel(dev_ctx, global_expert_count, IntArray({0}), @@ -285,6 +197,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, false, &fwd_expert_count); // fwd batch size + // VLOG(0) << "moe, fwd batch size"; SumKernel(dev_ctx, fwd_expert_count, IntArray({}), // axis is None @@ -292,6 +205,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, false, &fwd_batch_size); // step4.3 cumsum & assign pos + // VLOG(0) << "moe, cumsum & assign pos"; CumsumKernel(dev_ctx, local_expert_count, Scalar(0), @@ -332,10 +246,12 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, // step 5, MOEScatter // step 5.1, index select // suppose tmp_pos->shape != [0] + // VLOG(0) << "moe, index select"; IndexSelectKernel(dev_ctx, sliced_inp, temp_pos, 0, &index_select_out); if (world_size > 1) { - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); // step 5.2, global_scatter + // VLOG(0) << "moe, global_scatter"; if (map->has(moe_ring_id)) { GlobalScatterProcessGroupFunctor(dev_ctx, &index_select_out, @@ -358,6 +274,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, } // step 6, Expert Computation + // VLOG(0) << "moe, Expert Computation"; if (fwd_bsz != 0) { int last_index = 0; for (int idx = 0; idx < num_expert; idx++) { @@ -380,6 +297,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, Tensor tmp_inp = global_scatter_out.Slice(last_index, end); // linear1 matmul + // VLOG(0) << "moe, Expert Computation, linear1 mul"; MatMulAndAdd(dev_ctx, experts_weight1[idx], &tmp_inp, @@ -390,6 +308,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, &expert_out1, nullptr); // bias gelu + // VLOG(0) << "moe, Expert Computation, add bias & gelu"; fused_act_dropout_helper.DropoutActBias(dev_ctx, expert_out1.data(), experts_bias1[idx]->data(), @@ -405,6 +324,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, -127.0, approximate); // linear2 matmul & add + // VLOG(0) << "moe, Expert Computation, linear2 matmul & add"; MatMulAndAdd(dev_ctx, experts_weight2[idx], &act_bias_out, @@ -423,8 +343,9 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, all_expert_out = global_scatter_out; } // step7. MOEGather + // VLOG(0) << "moe, MOEGather"; if (world_size > 1) { - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); // step 7.1, global_gather if (map->has(moe_ring_id)) { GlobalGatherProcessGroupFunctor(dev_ctx, @@ -448,6 +369,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, } // step 7.2, local_gather or scatter // suppose pos->shape != [0] + // VLOG(0) << "moe, local_gather or scatter"; ScatterKernel(dev_ctx, moe_gather_out, pos, @@ -456,11 +378,13 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, &moe_gather_out); // step 8, reshape & bmm // moe gather out reshape + // VLOG(0) << "moe, reshape & bmm"; moe_gather_out.Resize({{sliced_bsz_seq, topk, d_model}}); topk_value.Resize({{sliced_bsz_seq, 1, topk}}); BmmKernel(dev_ctx, topk_value, moe_gather_out, &bmm_out); bmm_out.Resize({{sliced_bsz_seq, d_model}}); // step 9, AllGather + // VLOG(0) << "moe, AllGather"; if (mp_size > 1) { // all gather AllGather(bmm_out, all_gather_out, moe_ring_id, dev_ctx); @@ -468,9 +392,20 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, all_gather_out = bmm_out; } // step 10, reshape + // VLOG(0) << "moe, reshape"; all_gather_out.Resize(x_dim); // step 11, add residual - AddKernel(dev_ctx, all_gather_out, x, out); + // VLOG(0) << "moe, add residual"; + AddKernel(dev_ctx, all_gather_out, residual, out); + if (!pre_layer_norm) { + pre_layernorm_helper.LayerNorm(dev_ctx, + out->data(), + ln_scale_ptr, + ln_bias_ptr, + out->data(), + ln_mean_data, + ln_variance_data); + } } } // namespace phi @@ -481,4 +416,4 @@ PD_REGISTER_KERNEL(fused_moe_kernel, phi::FusedMoeKernel, float, double, - paddle::platform::float16) {} \ No newline at end of file + paddle::platform::float16) {} diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index b2767b1dd1cbf..c081bffbef993 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -168,6 +168,7 @@ def _update_list(self): 'fused_feedforward', 'fused_attention', 'fused_multi_transformer', + 'fused_multi_transformer_moe', } # The set of ops that don't support fp16 calculation diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index b23c94c7e4994..8884caca96ba6 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -110,7 +110,7 @@ def _keep_fp32_input(op, in_name): return in_name in { 'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias" } - if op_type == 'fused_multi_transformer': + if op_type in ['fused_multi_transformer', 'fused_multi_transformer_moe']: return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'} return False diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 98efe71421cf4..6ff2f9dc46c43 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -192,6 +192,7 @@ def pure_fp16_initialize(models): if isinstance(layer, (paddle.incubate.nn.FusedFeedForward, paddle.incubate.nn.FusedMultiHeadAttention, paddle.incubate.nn.FusedMultiTransformer, + paddle.incubate.nn.FusedMultiTransformerMoe, paddle.incubate.nn.FusedMoELayer)): layer._amp_decorate(dtype='float16') continue diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index 2a2def22bb3bf..c5ca8e38215d7 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -18,13 +18,15 @@ from .layer.fused_transformer import FusedMultiTransformer # noqa: F401 from .layer.fused_linear import FusedLinear # noqa: F401 from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401 -from .layer.fused_transformer import FusedMoELayer # tianyan01 add +from .layer.fused_transformer import FusedMoELayer # noqa: F401 +from .layer.fused_transformer import FusedMultiTransformerMoe # noqa: F401 __all__ = [ #noqa 'FusedMultiHeadAttention', 'FusedFeedForward', 'FusedTransformerEncoderLayer', 'FusedMultiTransformer', + 'FusedMultiTransformerMoe', 'FusedLinear', 'FusedBiasDropoutResidualLayerNorm', 'FusedMoELayer', diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index a39f1cb94c0c5..cd925f2a3df2e 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -847,6 +847,8 @@ def fused_multi_transformer( pre_layer_norm=True, epsilon=1e-05, cache_kvs=None, + beam_offset=None, + seq_lens=None, time_step=None, attn_mask=None, dropout_rate=0.0, @@ -1006,7 +1008,9 @@ def fused_multi_transformer( list(qkv_weights), list(qkv_biases), cache_kvs, + beam_offset, time_step, + seq_lens, attn_mask, list(linear_weights), list(linear_biases), @@ -1054,6 +1058,8 @@ def fused_multi_transformer( inputs['LnScale'] = list(ln_scales) inputs['LnBias'] = list(ln_biases) inputs['QKVW'] = list(qkv_weights) + if seq_lens is not None: + inputs['SeqLengths'] = seq_lens if qkv_biases is not None: inputs['QKVBias'] = list(qkv_biases) if cache_kvs is not None: @@ -1061,6 +1067,8 @@ def fused_multi_transformer( inputs['CacheKV'] = cache_kvs if time_step is not None: inputs['TimeStep'] = time_step + if beam_offset is not None: + inputs['BeamCacheOffset'] = beam_offset inputs['SrcMask'] = attn_mask inputs['OutLinearW'] = list(linear_weights) if linear_biases is not None: diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index b4f173fc725d3..d66097cdf51a1 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1379,7 +1379,7 @@ def get_attr(attrs, idx): self.activation = activation self.name = name - def forward(self, src, attn_mask=None, caches=None, time_step=None): + def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=None, time_step=None): """ Applies multi transformer layers on the input. @@ -1429,7 +1429,9 @@ def forward(self, src, attn_mask=None, caches=None, time_step=None): pre_layer_norm=self.normalize_before, epsilon=self._epsilon, cache_kvs=caches, + beam_offset=beam_offset, time_step=time_step, + seq_lens=seq_lens, attn_mask=attn_mask, dropout_rate=self.dropout_rate, activation=self.activation, @@ -1514,6 +1516,7 @@ def __init__(self, self.mp_rank = mp_group.rank self.mp_size = mp_group.nranks self.d_model = d_model + self.dim_feedforward = dim_feedforward self.top_k = top_k self.approximate = approximate self.ln_scale = self.create_parameter( @@ -1586,7 +1589,10 @@ def get_attr(attrs, idx): self.linear2_biases[i].name = "expert_" + self.linear2_biases[i].name def forward(self, inp): - inp = _C_ops.fused_moe_kernel( + bsz = inp.shape[0] + seq_len = inp.shape[1] + out = _C_ops.fused_moe_kernel( + inp, inp, self.gate_weight, self.gate_bias, @@ -1604,9 +1610,13 @@ def forward(self, inp): self.num_expert, self.world_size, -1 if self.group is None else self.group.id, - self.approximate + self.approximate, + bsz, + seq_len, + self.d_model, + self.dim_feedforward ) - return inp + return out def _amp_decorate(self, dtype): # tmp fix for amp.decorator(O2) @@ -1622,3 +1632,329 @@ def trans_to_fp16(l): _ = _to_dtype(self.gate_weight, dtype) _ = _to_dtype(self.gate_bias, dtype) self._dtype = dtype + + +class FusedMultiTransformerMoe(Layer): + """ + FusedMultiTransformerMoe + """ + def __init__( + self, + d_model, + embed_dim, + num_heads, + dim_feedforward, + dropout_rate=0.0, + activation="gelu", + normalize_before=True, + ln_scale_attrs=None, + ln_bias_attrs=None, + qkv_weight_attrs=None, + qkv_bias_attrs=None, + linear_weight_attrs=None, + linear_bias_attrs=None, + gate_weight_attrs=None, + gate_bias_attrs=None, + ffn_ln_scale_attrs=None, + ffn_ln_bias_attrs=None, + expert_weight1_attrs=None, + expert_bias1_attrs=None, + expert_weight2_attrs=None, + expert_bias2_attrs=None, + epsilon=1e-5, + num_layers=-1, + nranks=1, + trans_qkvw=True, + ring_id=-1, + num_expert=1, + top_k=2, + approximate=True, + moe_group=None, + mp_group=None, + name=None, + ): + super(FusedMultiTransformerMoe, self).__init__() + assert embed_dim > 0, ( + "Expected embed_dim to be greater than 0, " + "but received {}".format(embed_dim) + ) + assert ( + num_heads > 0 + ), "Expected nhead to be greater than 0, " "but received {}".format( + num_heads + ) + assert ( + dim_feedforward > 0 + ), "Expected dim_feedforward to be greater than 0, but received {}".format( + dim_feedforward + ) + # only support mp/dp + # for moe config + self.group = moe_group + self.world_size = 1 + if self.group is not None: + self.world_size = self.group.nranks + self.num_expert = num_expert + + self.mp_rank = 0 + self.mp_size = 1 + if mp_group is not None and mp_group.nranks > 1: + self.mp_rank = mp_group.rank + self.mp_size = mp_group.nranks + self.top_k = top_k + self.approximate = approximate + + # origin fmt config + self.normalize_before = normalize_before + self._dtype = self._helper.get_default_dtype() + self._epsilon = epsilon + self._trans_qkvw = trans_qkvw + self._ring_id = ring_id + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + # tensor model parallel + if nranks > 1: + assert ring_id != -1 + assert num_heads % nranks == 0 + num_heads = num_heads // nranks + + if isinstance(qkv_weight_attrs, (list, tuple, ParameterList)): + num_layers = len(qkv_weight_attrs) + assert num_layers > 0 + + self.ln_scales, self.ln_biases = ParameterList(), ParameterList() + self.qkv_weights, self.qkv_biases = ParameterList(), ParameterList() + self.linear_weights, self.linear_biases = ParameterList(), ParameterList() + self.gate_weights, self.gate_biases = ParameterList(), ParameterList() + self.ffn_ln_scales, self.ffn_ln_biases = ParameterList(), ParameterList() + self.expert_weights1, self.expert_biases1 = ParameterList(), ParameterList() + self.expert_weights2, self.expert_biases2 = ParameterList(), ParameterList() + def get_attr(attrs, idx): + if isinstance(attrs, (list, tuple, ParameterList)): + assert len(attrs) == num_layers + return attrs[idx] + return attrs + + for i in range(num_layers): + ln_scale_attr = get_attr(ln_scale_attrs, i) + ln_bias_attr = get_attr(ln_bias_attrs, i) + qkv_weight_attr = get_attr(qkv_weight_attrs, i) + qkv_bias_attr = get_attr(qkv_bias_attrs, i) + linear_weight_attr = get_attr(linear_weight_attrs, i) + linear_bias_attr = get_attr(linear_bias_attrs, i) + + ffn_ln_scale_attr = get_attr(ffn_ln_scale_attrs, i) + ffn_ln_bias_attr = get_attr(ffn_ln_bias_attrs, i) + gate_weight_attr = get_attr(gate_weight_attrs, i) + gate_bias_attr = get_attr(gate_bias_attrs, i) + + ln_scale = self.create_parameter( + attr=ln_scale_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0), + dtype="float32", + ) + ln_bias = self.create_parameter( + attr=ln_bias_attr, shape=[embed_dim], is_bias=True, dtype="float32" + ) + qkv_weight = self.create_parameter( + shape=[3, num_heads, self.head_dim, embed_dim] + if trans_qkvw + else [embed_dim, 3, num_heads, self.head_dim], + attr=qkv_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + qkv_bias = self.create_parameter( + shape=[3, num_heads, self.head_dim], + attr=qkv_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + linear_weight = self.create_parameter( + shape=[num_heads * self.head_dim, embed_dim], + attr=linear_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + linear_bias = self.create_parameter( + shape=[embed_dim], + attr=linear_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + + ffn_ln_scale = self.create_parameter( + shape=[embed_dim], + attr=ffn_ln_scale_attr, + is_bias=False, + default_initializer=Constant(1.0), + dtype="float32", + ) + ffn_ln_bias = self.create_parameter( + shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True, dtype="float32" + ) + gate_weight = self.create_parameter( + shape=[d_model, num_expert * self.world_size], + attr=gate_weight_attr, + dtype=self._dtype, + is_bias=False + ) + gate_bias = self.create_parameter( + shape=[num_expert * self.world_size], + attr=gate_bias_attr, + dtype=self._dtype, + is_bias=True + ) + + # tensor model parallel + if nranks > 1: + # column parallel + _set_var_distributed(qkv_weight) + _set_var_distributed(qkv_bias) + # row parallel + _set_var_distributed(linear_weight) + + self.ln_scales.append(ln_scale) + self.ln_biases.append(ln_bias) + self.qkv_weights.append(qkv_weight) + self.qkv_biases.append(qkv_bias) + self.linear_weights.append(linear_weight) + self.linear_biases.append(linear_bias) + + self.ffn_ln_scales.append(ffn_ln_scale) + self.ffn_ln_biases.append(ffn_ln_bias) + self.gate_weights.append(gate_weight) + self.gate_biases.append(gate_bias) + + for j in range(num_expert): + expert_weight1_attr = get_attr(expert_weight1_attrs, i * num_expert + j) + expert_bias1_attr = get_attr(expert_bias1_attrs, i * num_expert + j) + expert_weight2_attr = get_attr(expert_weight2_attrs, i * num_expert + j) + expert_bias2_attr = get_attr(expert_bias2_attrs, i * num_expert + j) + + expert_weight1 = self.create_parameter( + shape=[d_model, dim_feedforward], + attr=expert_weight1_attr, + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform() + ) + expert_bias1 = self.create_parameter( + shape=[dim_feedforward], + attr=expert_bias1_attr, + dtype=self._dtype, + is_bias=True, + default_initializer=nn.initializer.Constant(value=0.0) + ) + expert_weight2 = self.create_parameter( + shape=[dim_feedforward, d_model], + attr=expert_weight2_attr, + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform() + ) + expert_bias2 = self.create_parameter( + shape=[d_model], + attr=expert_bias2_attr, + dtype=self._dtype, + is_bias=True, + default_initializer=nn.initializer.Constant(value=0.0) + ) + expert_weight1.name = "expert_" + expert_weight1.name + expert_bias1.name = "expert_" + expert_bias1.name + expert_weight2.name = "expert_" + expert_weight2.name + expert_bias2.name = "expert_" + expert_bias2.name + self.expert_weights1.append(expert_weight1) + self.expert_biases1.append(expert_bias1) + self.expert_weights2.append(expert_weight2) + self.expert_biases2.append(expert_bias2) + self.dropout_rate = dropout_rate + self.activation = activation + self.name = name + + def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=None, time_step=None): + """ + forward + """ + cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer_moe( + src, + list(self.ln_scales), + list(self.ln_biases), + list(self.qkv_weights), + list(self.qkv_biases), + caches, + beam_offset, + time_step, + seq_lens, + attn_mask, + list(self.linear_weights), + list(self.linear_biases), + list(self.gate_weights), + list(self.gate_biases), + list(self.ffn_ln_scales), + list(self.ffn_ln_biases), + list(self.expert_weights1), + list(self.expert_biases1), + list(self.expert_weights2), + list(self.expert_biases2), + caches, + 'pre_layer_norm', + self.normalize_before, + 'epsilon', + self._epsilon, + 'dropout_rate', + self.dropout_rate, + 'is_test', + not self.training, + 'dropout_implementation', + 'upscale_in_train', + 'act_method', + self.activation, + 'trans_qkvw', + self._trans_qkvw, + 'ring_id', + self._ring_id, + 'topk', + self.top_k, + 'mp_size', + self.mp_size, + 'mp_rank', + self.mp_rank, + 'num_expert', + self.num_expert, + 'world_size', + self.world_size, + 'moe_ring_id', + -1 if self.group is None else self.group.id, + 'approximate', + self.approximate + ) + if caches is not None: + return final_out, cache_kv_out + return final_out + + def _amp_decorate(self, dtype): + # tmp fix for amp.decorator(O2) + def trans_to_fp16(l): + for param in l: + if param is not None: + with no_grad(): + param_applied = _to_dtype(param, dtype) + trans_to_fp16(self.qkv_weights) + trans_to_fp16(self.qkv_biases) + trans_to_fp16(self.linear_weights) + trans_to_fp16(self.linear_biases) + trans_to_fp16(self.gate_weights) + trans_to_fp16(self.gate_biases) + trans_to_fp16(self.expert_weights1) + trans_to_fp16(self.expert_biases1) + trans_to_fp16(self.expert_weights2) + trans_to_fp16(self.expert_biases2) + self._dtype = dtype \ No newline at end of file