From 6b5a155065ed12c7e587fe60f72aeef7a77bafb4 Mon Sep 17 00:00:00 2001 From: gzy19990617 Date: Thu, 12 Sep 2024 12:59:01 +0000 Subject: [PATCH 1/5] add top_p kernel --- csrc/gpu/sample_kernels/sampling.cuh | 329 ++++++++++++++++++ .../top_p_sampling_from_probs.cu | 77 ++++ csrc/gpu/sample_kernels/utils.cuh | 233 +++++++++++++ csrc/setup_cuda.py | 2 + .../transformers/generation_utils.py | 24 +- 5 files changed, 662 insertions(+), 3 deletions(-) create mode 100644 csrc/gpu/sample_kernels/sampling.cuh create mode 100644 csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu create mode 100644 csrc/gpu/sample_kernels/utils.cuh diff --git a/csrc/gpu/sample_kernels/sampling.cuh b/csrc/gpu/sample_kernels/sampling.cuh new file mode 100644 index 000000000000..09b3820a0077 --- /dev/null +++ b/csrc/gpu/sample_kernels/sampling.cuh @@ -0,0 +1,329 @@ +#pragma once + +#include +#include +#include +#include +#include "utils.cuh" + + +namespace sampling { + +using namespace cub; + +constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; +constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100) +#define SAMPLING_CUB_SUBTRACTLEFT_DEFINED +#endif + +template +struct Pair { + T value; + int count; + + __device__ Pair operator+(const Pair& other) const { + return {value + other.value, count + other.count}; + } + __device__ Pair& operator+=(const Pair& other) { + value += other.value; + count += other.count; + return *this; + } +}; + +struct BoolDiffOp { + __device__ __forceinline__ bool operator()(const bool& lhs, const bool& rhs) const { + return lhs != rhs; + } +}; + +template +struct SamplingTempStorage { + union { + T deterministic_scan[BLOCK_THREADS / 32]; + typename BlockScan::TempStorage scan; + typename BlockReduce::TempStorage reduce; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_pair; + typename BlockAdjacentDifference::TempStorage adj_diff; + } block_prim; + struct { + int32_t sampled_id; + union { + T value; + Pair pair; + T max_p; + } block_aggregate; + } data; +}; + +/*! + * \brief Deterministic inclusive scan implementation, use Belloch scan algorithm. + * \note This implementation is slower than the cub::BlockScan, but it is deterministic. 
+ */ +template +__device__ __forceinline__ void DeterministicInclusiveSum( + const T* in_data, T* out_data, + SamplingTempStorage* temp_storage) { + T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan; + T thread_data[VEC_SIZE]; + T thread_sum = 0; +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + thread_sum += in_data[i]; + thread_data[i] = thread_sum; + } + + T thread_exclusive_prefix_sum = thread_sum; + +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + T tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + thread_exclusive_prefix_sum += tmp; + } + } + + T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); + if (threadIdx.x % 32 == 31) { + thread_exclusive_prefix_sum = 0; + } + +#pragma unroll + for (uint32_t offset = 16; offset >= 1; offset /= 2) { + T tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; + } + if ((threadIdx.x + 1) % (offset * 2) == offset) { + thread_exclusive_prefix_sum = tmp; + } + } + + smem_prefix_sum[threadIdx.x / 32] = warp_sum; + __syncthreads(); + + if (threadIdx.x < 32) { + T warp_exclusive_prefix_sum = + (threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0; + +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + T tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + warp_exclusive_prefix_sum += tmp; + } + } + + if (threadIdx.x % 32 == 31) { + warp_exclusive_prefix_sum = 0; + } + +#pragma unroll + for (uint32_t offset = 16; offset >= 1; offset /= 2) { + T tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; + } + if ((threadIdx.x + 1) % (offset * 2) == offset) { + warp_exclusive_prefix_sum = tmp; + } + } + if (threadIdx.x < BLOCK_THREADS / 32) { + smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum; + } + } + __syncthreads(); + +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + out_data[i] = smem_prefix_sum[threadIdx.x / 32] + thread_exclusive_prefix_sum + thread_data[i]; + } +} + +template +__device__ __forceinline__ void DeviceSamplingFromProb( + uint32_t i, uint32_t d, T threshold, T u, vec_t prob_vec, T& aggregate, + SamplingTempStorage* temp_storage) { + const uint32_t tx = threadIdx.x; + T prob_greater_than_threshold[VEC_SIZE]; + T inclusive_cdf[VEC_SIZE]; + bool greater_than_u[VEC_SIZE], valid[VEC_SIZE]; +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + prob_greater_than_threshold[j] = (prob_vec[j] > threshold) ? 
prob_vec[j] : T(0); + valid[j] = prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d; + } + T aggregate_local = + BlockReduce(temp_storage->block_prim.reduce) + .Sum(prob_greater_than_threshold); + if (tx == 0) { + temp_storage->data.block_aggregate.value = aggregate_local; + } + __syncthreads(); + aggregate_local = temp_storage->data.block_aggregate.value; + + if (aggregate + aggregate_local > u) { + if constexpr (DETERMINISTIC) { + DeterministicInclusiveSum( + prob_greater_than_threshold, inclusive_cdf, temp_storage); + } else { + BlockScan(temp_storage->block_prim.scan) + .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); + + __syncthreads(); + } + +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + greater_than_u[j] = inclusive_cdf[j] + aggregate > u; + } + + bool greater_than_u_diff[VEC_SIZE]; +#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED + BlockAdjacentDifference(temp_storage->block_prim.adj_diff) + .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); +#else + BlockAdjacentDifference(temp_storage->block_prim.adj_diff) + .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); +#endif + __syncthreads(); + +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + if (greater_than_u_diff[j] && valid[j]) { + if constexpr (DETERMINISTIC) { + temp_storage->data.sampled_id = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + } else { + // cub's block scan result might not be monotonic, so we need to find the first element + atomicMin(&(temp_storage->data.sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j); + } + } + } + __syncthreads(); + } + aggregate += aggregate_local; +} + +template +__global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output, + bool* success, IdType* row_indices, float* top_p_arr, + float* top_p_val, uint32_t d, uint32_t max_top_p_rounds) { + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + float top_p = (top_p_arr == nullptr) ? top_p_val[bx] : top_p_arr[bx]; + + const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = reinterpret_cast< + SamplingTempStorage&>(smem_sampling); + + vec_t probs_vec; + DType aggregate; + DType q = DType(1); + DType pivot = DType(0); + IdType sampled_id; + for (uint32_t round = 0; round < max_top_p_rounds; ++round) { + temp_storage.data.sampled_id = d - 1; + __syncthreads(); + DType u = uniform_samples[round * batch_size + bx] * q; + aggregate = DType(0); + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb(i, d, pivot, u, probs_vec, aggregate, + &temp_storage); + if (aggregate > u) { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.data.sampled_id; + pivot = max(pivot, probs[row_idx * d + sampled_id]); + + DType aggregate_gt_pivot = DType(0); + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DType probs_gt_pivot[VEC_SIZE]; +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_gt_pivot[j] = (probs_vec[j] > pivot) ? 
probs_vec[j] : DType(0); + } + + aggregate_gt_pivot += BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot); + if (tx == 0) { + temp_storage.data.block_aggregate.value = aggregate_gt_pivot; + } + __syncthreads(); + } + q = temp_storage.data.block_aggregate.value; + if (float(q) < top_p) { + break; + } + } + __syncthreads(); + if (tx == 0) { + output[bx] = sampled_id; + if (float(q) >= top_p) { + // failed to sample within MAX_TOP_P_ROUNDS + if (success != nullptr) { + success[bx] = false; + } + } else { + if (success != nullptr) { + success[bx] = true; + } + } + } +} + + + +template +cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, + T* top_p_arr, uint32_t batch_size, const T* top_p_val, uint32_t d, + uint32_t max_top_p_rounds, bool deterministic, + cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t smem_size = sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + IdType* row_indices_placeholder = nullptr; + void* args[] = {&probs, &uniform_samples, &output, &success, &row_indices_placeholder, + &top_p_arr, &top_p_val, &d, &max_top_p_rounds}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopPSamplingFromProbKernel; + CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + })}); + return cudaSuccess; +} + +} // namespace sampling + + diff --git a/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu b/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu new file mode 100644 index 000000000000..d22c604856ab --- /dev/null +++ b/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu @@ -0,0 +1,77 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "sampling.cuh" +#include "helper.h" + +std::vector top_p_sampling_from_probs(const paddle::Tensor& probs, + const paddle::Tensor& uniform_samples, + const paddle::Tensor& top_p + ) { + + std::vector probs_shape = probs.shape(); + unsigned int batch_size = probs_shape[0]; + unsigned int vocab_size = probs_shape[1]; + std::vector uniform_samples_shape = uniform_samples.shape(); + PD_CHECK(uniform_samples_shape[0], batch_size); + unsigned int max_top_p_rounds = uniform_samples_shape[1]; + // todo: add parameter for deterministic, now default is true + bool deterministic = true; + paddle::Tensor probs_input; + paddle::Tensor uniform_samples_input; + + probs_input = paddle::experimental::cast(probs,paddle::DataType::FLOAT32); + uniform_samples_input =paddle::experimental::cast(uniform_samples, paddle::DataType::FLOAT32); + auto cu_stream = probs.stream(); + + auto samples = paddle::full({batch_size}, 0, paddle::DataType::INT32, probs.place()); + auto success = paddle::full({batch_size}, 0, paddle::DataType::BOOL, probs.place()); + + cudaError_t status = sampling::TopPSamplingFromProb( + probs_input.data(), uniform_samples_input.data(), + samples.data(), success.data(), + nullptr, batch_size, top_p.data(), + vocab_size, max_top_p_rounds, deterministic, cu_stream); + PD_CHECK(status == cudaSuccess, + "SamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); + paddle::Tensor samples_output; + samples_output =paddle::experimental::cast(samples, paddle::DataType::INT64); + return {samples_output}; +} + +std::vector> top_p_sampling_from_probs_InferShape(const std::vector& probs_shape, + const std::vector& uniform_samples_shape, + const std::vector& top_p_shape + ) { + int64_t bs = probs_shape[0]; + return {{bs, 1}}; +} + +std::vector top_p_sampling_from_probs_InferDtype(const paddle::DataType& probs_dtype, + const paddle::DataType& uniform_samples_dtype, + const paddle::DataType& top_p_shape) +{ + return {probs_dtype}; +} + +PD_BUILD_OP(top_p_sampling_from_probs) + .Inputs({"probs", "uniform_samples", "top_p"}) + .Outputs({"samples"}) + .SetKernelFn(PD_KERNEL(top_p_sampling_from_probs)) + .SetInferShapeFn(PD_INFER_SHAPE(top_p_sampling_from_probs_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(top_p_sampling_from_probs_InferDtype)); + + + + diff --git a/csrc/gpu/sample_kernels/utils.cuh b/csrc/gpu/sample_kernels/utils.cuh new file mode 100644 index 000000000000..ae322b6e88f1 --- /dev/null +++ b/csrc/gpu/sample_kernels/utils.cuh @@ -0,0 +1,233 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +/******************* utils *******************/ +#define STR_HELPER(x) #x +#define STR(x) STR_HELPER(x) + +#ifndef NDEBUG +#define CUDA_CALL(func, ...) \ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e << ") " << __FILE__ \ + << ": line " << __LINE__ << " at function " << STR(func) << std::endl; \ + return e; \ + } \ + } +#else +#define CUDA_CALL(func, ...) \ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + return e; \ + } \ + } +#endif + +#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \ + if (deterministic) { \ + constexpr bool DETERMINISTIC = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool DETERMINISTIC = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) 
\ + switch (aligned_vec_size) { \ + case 16: { \ + constexpr size_t ALIGNED_VEC_SIZE = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 8: { \ + constexpr size_t ALIGNED_VEC_SIZE = 8; \ + __VA_ARGS__ \ + break; \ + } \ + case 4: { \ + constexpr size_t ALIGNED_VEC_SIZE = 4; \ + __VA_ARGS__ \ + break; \ + } \ + case 2: { \ + constexpr size_t ALIGNED_VEC_SIZE = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 1: { \ + constexpr size_t ALIGNED_VEC_SIZE = 1; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ + throw std::invalid_argument(err_msg.str()); \ + } \ + } + +/******************* vec_t *******************/ + +#define SAMPLING_INLINE inline __attribute__((always_inline)) __device__ +template +struct vec_t { + SAMPLING_INLINE float_t& operator[](size_t i); + SAMPLING_INLINE const float_t& operator[](size_t i) const; + SAMPLING_INLINE void fill(float_t val); + SAMPLING_INLINE void load(const float_t* ptr); + SAMPLING_INLINE void store(float_t* ptr) const; + template + SAMPLING_INLINE void cast_from(const vec_t& src); + template + SAMPLING_INLINE void cast_load(const T* ptr); + template + SAMPLING_INLINE void cast_store(T* ptr) const; + SAMPLING_INLINE static void memcpy(float_t* dst, const float_t* src); + SAMPLING_INLINE float_t* ptr(); +}; + +// float x 1 +template <> +struct vec_t { + float data; + + SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } + SAMPLING_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } + SAMPLING_INLINE float* ptr() { return reinterpret_cast(&data); } + SAMPLING_INLINE void fill(float val); + SAMPLING_INLINE void load(const float* ptr); + SAMPLING_INLINE void store(float* ptr) const; + template + SAMPLING_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SAMPLING_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SAMPLING_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + SAMPLING_INLINE static void memcpy(float* dst, const float* src); +}; + +SAMPLING_INLINE void vec_t::fill(float val) { data = val; } + +SAMPLING_INLINE void vec_t::load(const float* ptr) { data = *ptr; } + +SAMPLING_INLINE void vec_t::store(float* ptr) const { *ptr = data; } + +SAMPLING_INLINE void vec_t::memcpy(float* dst, const float* src) { *dst = *src; } + +// float x 2 +template <> +struct vec_t { + float2 data; + + SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } + SAMPLING_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } + SAMPLING_INLINE float* ptr() { return reinterpret_cast(&data); } + SAMPLING_INLINE void fill(float val); + SAMPLING_INLINE void load(const float* ptr); + SAMPLING_INLINE void store(float* ptr) const; + template + SAMPLING_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SAMPLING_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SAMPLING_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + SAMPLING_INLINE static void memcpy(float* dst, const float* src); +}; + +SAMPLING_INLINE void vec_t::fill(float val) { data = make_float2(val, val); } + +SAMPLING_INLINE void vec_t::load(const float* ptr) { data = *((float2*)ptr); } + +SAMPLING_INLINE void vec_t::store(float* ptr) const { *((float2*)ptr) = data; } + +SAMPLING_INLINE void vec_t::memcpy(float* dst, const float* 
src) { + *((float2*)dst) = *((float2*)src); +} + +// float x 4 or more +template +struct vec_t { + float4 data[vec_size / 4]; + + SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; } + SAMPLING_INLINE const float& operator[](size_t i) const { return ((const float*)(data))[i]; } + SAMPLING_INLINE float* ptr() { return reinterpret_cast(&data); } + SAMPLING_INLINE void fill(float val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); + } + } + SAMPLING_INLINE void load(const float* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4*)ptr)[i]; + } + } + SAMPLING_INLINE void store(float* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)ptr)[i] = data[i]; + } + } + template + SAMPLING_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SAMPLING_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SAMPLING_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + SAMPLING_INLINE static void memcpy(float* dst, const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)dst)[i] = ((float4*)src)[i]; + } + } +}; + +inline std::pair GetCudaComputeCapability() { + int device_id = 0; + cudaGetDevice(&device_id); + int major = 0, minor = 0; + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id); + return std::make_pair(major, minor); +} + +/******************* math *******************/ +__forceinline__ __device__ float ptx_rcp(float x) { + float y; + asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +template +__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { + return (x + y - 1) / y; +} \ No newline at end of file diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index a06eb6b9e760..269a37acded3 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -107,6 +107,7 @@ def get_gencode_flags(): "./gpu/dequant_int8.cu", "./gpu/flash_attn_bwd.cc", "./gpu/tune_cublaslt_gemm.cu", + "./gpu/sample_kernels/top_p_sampling_from_probs.cu", ] cutlass_dir = "third_party/cutlass" @@ -135,6 +136,7 @@ def get_gencode_flags(): "-Ithird_party/cutlass/include", "-Ithird_party/nlohmann_json/single_include", "-Igpu/fp8_gemm_with_cutlass", + "-Igpu/sample_kernels", "-Igpu", ] cc = get_sm_version() diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py index 04bfafb3167f..29409c1b2f39 100644 --- a/paddlenlp/experimental/transformers/generation_utils.py +++ b/paddlenlp/experimental/transformers/generation_utils.py @@ -329,8 +329,16 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): # sample probs = F.softmax(logits) - # compute next_tokens, use paddle.tensor.top_p_sampling - _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p) + # compute next_tokens + try: + from paddlenlp_ops import top_p_sampling_from_probs + + # max_rounds default is 32 + bs = probs.shape[0] + uniform_samples = paddle.randn([bs, 32]) + next_tokens = top_p_sampling_from_probs(probs, uniform_samples, top_p) + except: + _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p) if self.config.tensor_parallel_degree > 1: paddle.distributed.broadcast(next_tokens, 0) @@ -667,7 +675,17 @@ def _post_process_( # sample probs 
= F.softmax(logits) - _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p) + + # compute next_tokens + try: + from paddlenlp_ops import top_p_sampling_from_probs + + bs = probs.shape[0] + # max_rounds default is 32 + uniform_samples = paddle.randn([bs, 32]) + next_tokens = top_p_sampling_from_probs(probs, uniform_samples, top_p) + except: + _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p) if self.config.tensor_parallel_degree > 1: paddle.distributed.broadcast(next_tokens, 0) From beb982373cbdfa16faa92f90ff2cf624250d4f1d Mon Sep 17 00:00:00 2001 From: gzy19990617 Date: Thu, 12 Sep 2024 13:00:06 +0000 Subject: [PATCH 2/5] fix code style --- csrc/gpu/sample_kernels/sampling.cuh | 202 ++++++++++++------ .../top_p_sampling_from_probs.cu | 74 ++++--- csrc/gpu/sample_kernels/utils.cuh | 58 +++-- 3 files changed, 217 insertions(+), 117 deletions(-) diff --git a/csrc/gpu/sample_kernels/sampling.cuh b/csrc/gpu/sample_kernels/sampling.cuh index 09b3820a0077..64049b57d96f 100644 --- a/csrc/gpu/sample_kernels/sampling.cuh +++ b/csrc/gpu/sample_kernels/sampling.cuh @@ -4,6 +4,7 @@ #include #include #include + #include "utils.cuh" @@ -34,19 +35,24 @@ struct Pair { }; struct BoolDiffOp { - __device__ __forceinline__ bool operator()(const bool& lhs, const bool& rhs) const { + __device__ __forceinline__ bool operator()(const bool& lhs, + const bool& rhs) const { return lhs != rhs; } }; -template struct SamplingTempStorage { union { T deterministic_scan[BLOCK_THREADS / 32]; typename BlockScan::TempStorage scan; - typename BlockReduce::TempStorage reduce; - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_pair; + typename BlockReduce::TempStorage + reduce; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage + reduce_pair; typename BlockAdjacentDifference::TempStorage adj_diff; } block_prim; struct { @@ -60,14 +66,20 @@ struct SamplingTempStorage { }; /*! - * \brief Deterministic inclusive scan implementation, use Belloch scan algorithm. - * \note This implementation is slower than the cub::BlockScan, but it is deterministic. + * \brief Deterministic inclusive scan implementation, use Belloch scan + * algorithm. \note This implementation is slower than the cub::BlockScan, but + * it is deterministic. 
*/ -template +template __device__ __forceinline__ void DeterministicInclusiveSum( - const T* in_data, T* out_data, - SamplingTempStorage* temp_storage) { + const T* in_data, + T* out_data, + SamplingTempStorage* + temp_storage) { T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan; T thread_data[VEC_SIZE]; T thread_sum = 0; @@ -87,7 +99,8 @@ __device__ __forceinline__ void DeterministicInclusiveSum( } } - T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); + T warp_sum = __shfl_sync( + 0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); if (threadIdx.x % 32 == 31) { thread_exclusive_prefix_sum = 0; } @@ -140,27 +153,40 @@ __device__ __forceinline__ void DeterministicInclusiveSum( #pragma unroll for (uint32_t i = 0; i < VEC_SIZE; ++i) { - out_data[i] = smem_prefix_sum[threadIdx.x / 32] + thread_exclusive_prefix_sum + thread_data[i]; + out_data[i] = smem_prefix_sum[threadIdx.x / 32] + + thread_exclusive_prefix_sum + thread_data[i]; } } -template +template __device__ __forceinline__ void DeviceSamplingFromProb( - uint32_t i, uint32_t d, T threshold, T u, vec_t prob_vec, T& aggregate, - SamplingTempStorage* temp_storage) { + uint32_t i, + uint32_t d, + T threshold, + T u, + vec_t prob_vec, + T& aggregate, + SamplingTempStorage* + temp_storage) { const uint32_t tx = threadIdx.x; T prob_greater_than_threshold[VEC_SIZE]; T inclusive_cdf[VEC_SIZE]; bool greater_than_u[VEC_SIZE], valid[VEC_SIZE]; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - prob_greater_than_threshold[j] = (prob_vec[j] > threshold) ? prob_vec[j] : T(0); - valid[j] = prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d; + prob_greater_than_threshold[j] = + (prob_vec[j] > threshold) ? prob_vec[j] : T(0); + valid[j] = + prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d; } - T aggregate_local = - BlockReduce(temp_storage->block_prim.reduce) - .Sum(prob_greater_than_threshold); + T aggregate_local = BlockReduce( + temp_storage->block_prim.reduce) + .Sum(prob_greater_than_threshold); if (tx == 0) { temp_storage->data.block_aggregate.value = aggregate_local; } @@ -169,7 +195,11 @@ __device__ __forceinline__ void DeviceSamplingFromProb( if (aggregate + aggregate_local > u) { if constexpr (DETERMINISTIC) { - DeterministicInclusiveSum( + DeterministicInclusiveSum( prob_greater_than_threshold, inclusive_cdf, temp_storage); } else { BlockScan(temp_storage->block_prim.scan) @@ -185,11 +215,15 @@ __device__ __forceinline__ void DeviceSamplingFromProb( bool greater_than_u_diff[VEC_SIZE]; #ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED - BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); + BlockAdjacentDifference( + temp_storage->block_prim.adj_diff) + .SubtractLeft( + greater_than_u, greater_than_u_diff, BoolDiffOp()); #else - BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); + BlockAdjacentDifference( + temp_storage->block_prim.adj_diff) + .FlagHeads( + greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); #endif __syncthreads(); @@ -197,10 +231,13 @@ __device__ __forceinline__ void DeviceSamplingFromProb( for (uint32_t j = 0; j < VEC_SIZE; ++j) { if (greater_than_u_diff[j] && valid[j]) { if constexpr (DETERMINISTIC) { - temp_storage->data.sampled_id = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + temp_storage->data.sampled_id = + (i * BLOCK_THREADS + tx) * VEC_SIZE + j; } else { - // cub's 
block scan result might not be monotonic, so we need to find the first element - atomicMin(&(temp_storage->data.sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j); + // cub's block scan result might not be monotonic, so we need to find + // the first element + atomicMin(&(temp_storage->data.sampled_id), + (i * BLOCK_THREADS + tx) * VEC_SIZE + j); } } } @@ -209,23 +246,38 @@ __device__ __forceinline__ void DeviceSamplingFromProb( aggregate += aggregate_local; } -template -__global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output, - bool* success, IdType* row_indices, float* top_p_arr, - float* top_p_val, uint32_t d, uint32_t max_top_p_rounds) { +template +__global__ void TopPSamplingFromProbKernel(DType* probs, + DType* uniform_samples, + IdType* output, + bool* success, + IdType* row_indices, + float* top_p_arr, + float* top_p_val, + uint32_t d, + uint32_t max_top_p_rounds) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; float top_p = (top_p_arr == nullptr) ? top_p_val[bx] : top_p_arr[bx]; const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; - extern __shared__ __align__( - alignof(SamplingTempStorage)) + extern __shared__ __align__(alignof(SamplingTempStorage)) uint8_t smem_sampling[]; - auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem_sampling); + auto& temp_storage = + reinterpret_cast&>(smem_sampling); vec_t probs_vec; DType aggregate; @@ -240,12 +292,17 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(DType(0)); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + probs_vec.load(probs + row_idx * d + + (i * BLOCK_THREADS + tx) * VEC_SIZE); } - DeviceSamplingFromProb(i, d, pivot, u, probs_vec, aggregate, - &temp_storage); + DeviceSamplingFromProb( + i, d, pivot, u, probs_vec, aggregate, &temp_storage); if (aggregate > u) { break; } @@ -258,7 +315,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(DType(0)); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + probs_vec.load(probs + row_idx * d + + (i * BLOCK_THREADS + tx) * VEC_SIZE); } DType probs_gt_pivot[VEC_SIZE]; @@ -267,8 +325,9 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, probs_gt_pivot[j] = (probs_vec[j] > pivot) ? 
probs_vec[j] : DType(0); } - aggregate_gt_pivot += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot); + aggregate_gt_pivot += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot); if (tx == 0) { temp_storage.data.block_aggregate.value = aggregate_gt_pivot; } @@ -296,34 +355,53 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, } - template -cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, - T* top_p_arr, uint32_t batch_size, const T* top_p_val, uint32_t d, - uint32_t max_top_p_rounds, bool deterministic, +cudaError_t TopPSamplingFromProb(T* probs, + T* uniform_samples, + IdType* output, + bool* success, + T* top_p_arr, + uint32_t batch_size, + const T* top_p_val, + uint32_t d, + uint32_t max_top_p_rounds, + bool deterministic, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - const uint32_t smem_size = sizeof(SamplingTempStorage); + const uint32_t smem_size = + sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); IdType* row_indices_placeholder = nullptr; - void* args[] = {&probs, &uniform_samples, &output, &success, &row_indices_placeholder, - &top_p_arr, &top_p_val, &d, &max_top_p_rounds}; + void* args[] = {&probs, + &uniform_samples, + &output, + &success, + &row_indices_placeholder, + &top_p_arr, + &top_p_val, + &d, + &max_top_p_rounds}; DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopPSamplingFromProbKernel; - CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + vec_size, + VEC_SIZE, + {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopPSamplingFromProbKernel; + CUDA_CALL(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + CUDA_CALL(cudaLaunchKernel( + (void*)kernel, nblks, nthrs, args, smem_size, stream)); })}); return cudaSuccess; } } // namespace sampling - - diff --git a/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu b/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu index d22c604856ab..121a9ff40603 100644 --- a/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu +++ b/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu @@ -1,68 +1,78 @@ // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include "sampling.cuh" #include "helper.h" +#include "sampling.cuh" -std::vector top_p_sampling_from_probs(const paddle::Tensor& probs, - const paddle::Tensor& uniform_samples, - const paddle::Tensor& top_p - ) { - +std::vector top_p_sampling_from_probs( + const paddle::Tensor& probs, + const paddle::Tensor& uniform_samples, + const paddle::Tensor& top_p) { std::vector probs_shape = probs.shape(); unsigned int batch_size = probs_shape[0]; unsigned int vocab_size = probs_shape[1]; std::vector uniform_samples_shape = uniform_samples.shape(); - PD_CHECK(uniform_samples_shape[0], batch_size); + PD_CHECK(uniform_samples_shape[0], batch_size); unsigned int max_top_p_rounds = uniform_samples_shape[1]; // todo: add parameter for deterministic, now default is true bool deterministic = true; paddle::Tensor probs_input; paddle::Tensor uniform_samples_input; - probs_input = paddle::experimental::cast(probs,paddle::DataType::FLOAT32); - uniform_samples_input =paddle::experimental::cast(uniform_samples, paddle::DataType::FLOAT32); + probs_input = paddle::experimental::cast(probs, paddle::DataType::FLOAT32); + uniform_samples_input = + paddle::experimental::cast(uniform_samples, paddle::DataType::FLOAT32); auto cu_stream = probs.stream(); - auto samples = paddle::full({batch_size}, 0, paddle::DataType::INT32, probs.place()); - auto success = paddle::full({batch_size}, 0, paddle::DataType::BOOL, probs.place()); + auto samples = + paddle::full({batch_size}, 0, paddle::DataType::INT32, probs.place()); + auto success = + paddle::full({batch_size}, 0, paddle::DataType::BOOL, probs.place()); cudaError_t status = sampling::TopPSamplingFromProb( - probs_input.data(), uniform_samples_input.data(), - samples.data(), success.data(), - nullptr, batch_size, top_p.data(), - vocab_size, max_top_p_rounds, deterministic, cu_stream); - PD_CHECK(status == cudaSuccess, - "SamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); + probs_input.data(), + uniform_samples_input.data(), + samples.data(), + success.data(), + nullptr, + batch_size, + top_p.data(), + vocab_size, + max_top_p_rounds, + deterministic, + cu_stream); + PD_CHECK(status == cudaSuccess, + "SamplingFromProbs failed with error code " + + std::string(cudaGetErrorString(status))); paddle::Tensor samples_output; - samples_output =paddle::experimental::cast(samples, paddle::DataType::INT64); + samples_output = paddle::experimental::cast(samples, paddle::DataType::INT64); return {samples_output}; } -std::vector> top_p_sampling_from_probs_InferShape(const std::vector& probs_shape, - const std::vector& uniform_samples_shape, - const std::vector& top_p_shape - ) { +std::vector> top_p_sampling_from_probs_InferShape( + const std::vector& probs_shape, + const std::vector& uniform_samples_shape, + const std::vector& top_p_shape) { int64_t bs = probs_shape[0]; return {{bs, 1}}; } -std::vector top_p_sampling_from_probs_InferDtype(const paddle::DataType& probs_dtype, - const paddle::DataType& uniform_samples_dtype, - const paddle::DataType& top_p_shape) -{ - return {probs_dtype}; +std::vector top_p_sampling_from_probs_InferDtype( + const paddle::DataType& probs_dtype, + const paddle::DataType& uniform_samples_dtype, + const paddle::DataType& top_p_shape) { + return {probs_dtype}; } PD_BUILD_OP(top_p_sampling_from_probs) @@ -71,7 +81,3 @@ PD_BUILD_OP(top_p_sampling_from_probs) .SetKernelFn(PD_KERNEL(top_p_sampling_from_probs)) .SetInferShapeFn(PD_INFER_SHAPE(top_p_sampling_from_probs_InferShape)) 
.SetInferDtypeFn(PD_INFER_DTYPE(top_p_sampling_from_probs_InferDtype)); - - - - diff --git a/csrc/gpu/sample_kernels/utils.cuh b/csrc/gpu/sample_kernels/utils.cuh index ae322b6e88f1..b2ab5416f4b3 100644 --- a/csrc/gpu/sample_kernels/utils.cuh +++ b/csrc/gpu/sample_kernels/utils.cuh @@ -1,7 +1,8 @@ #pragma once -#include #include +#include + #include #include #include @@ -13,22 +14,23 @@ #define STR(x) STR_HELPER(x) #ifndef NDEBUG -#define CUDA_CALL(func, ...) \ - { \ - cudaError_t e = (func); \ - if (e != cudaSuccess) { \ - std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e << ") " << __FILE__ \ - << ": line " << __LINE__ << " at function " << STR(func) << std::endl; \ - return e; \ - } \ +#define CUDA_CALL(func, ...) \ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \ + << ") " << __FILE__ << ": line " << __LINE__ \ + << " at function " << STR(func) << std::endl; \ + return e; \ + } \ } #else #define CUDA_CALL(func, ...) \ - { \ - cudaError_t e = (func); \ - if (e != cudaSuccess) { \ - return e; \ - } \ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + return e; \ + } \ } #endif @@ -101,7 +103,9 @@ struct vec_t { float data; SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } - SAMPLING_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } + SAMPLING_INLINE const float& operator[](size_t i) const { + return ((const float*)(&data))[i]; + } SAMPLING_INLINE float* ptr() { return reinterpret_cast(&data); } SAMPLING_INLINE void fill(float val); SAMPLING_INLINE void load(const float* ptr); @@ -127,7 +131,9 @@ SAMPLING_INLINE void vec_t::load(const float* ptr) { data = *ptr; } SAMPLING_INLINE void vec_t::store(float* ptr) const { *ptr = data; } -SAMPLING_INLINE void vec_t::memcpy(float* dst, const float* src) { *dst = *src; } +SAMPLING_INLINE void vec_t::memcpy(float* dst, const float* src) { + *dst = *src; +} // float x 2 template <> @@ -135,7 +141,9 @@ struct vec_t { float2 data; SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } - SAMPLING_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } + SAMPLING_INLINE const float& operator[](size_t i) const { + return ((const float*)(&data))[i]; + } SAMPLING_INLINE float* ptr() { return reinterpret_cast(&data); } SAMPLING_INLINE void fill(float val); SAMPLING_INLINE void load(const float* ptr); @@ -155,11 +163,17 @@ struct vec_t { SAMPLING_INLINE static void memcpy(float* dst, const float* src); }; -SAMPLING_INLINE void vec_t::fill(float val) { data = make_float2(val, val); } +SAMPLING_INLINE void vec_t::fill(float val) { + data = make_float2(val, val); +} -SAMPLING_INLINE void vec_t::load(const float* ptr) { data = *((float2*)ptr); } +SAMPLING_INLINE void vec_t::load(const float* ptr) { + data = *((float2*)ptr); +} -SAMPLING_INLINE void vec_t::store(float* ptr) const { *((float2*)ptr) = data; } +SAMPLING_INLINE void vec_t::store(float* ptr) const { + *((float2*)ptr) = data; +} SAMPLING_INLINE void vec_t::memcpy(float* dst, const float* src) { *((float2*)dst) = *((float2*)src); @@ -171,7 +185,9 @@ struct vec_t { float4 data[vec_size / 4]; SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; } - SAMPLING_INLINE const float& operator[](size_t i) const { return ((const float*)(data))[i]; } + SAMPLING_INLINE const float& operator[](size_t i) const { + return ((const float*)(data))[i]; + } SAMPLING_INLINE float* 
ptr() { return reinterpret_cast(&data); } SAMPLING_INLINE void fill(float val) { #pragma unroll From 10630ec549dd9aa1aaed9ac2fa7108d2fc7532f8 Mon Sep 17 00:00:00 2001 From: gzy19990617 Date: Sat, 14 Sep 2024 06:00:53 +0000 Subject: [PATCH 3/5] fix review --- csrc/gpu/sample_kernels/sampling.cuh | 19 +++++- .../top_p_sampling_from_probs.cu | 65 +++++++++---------- csrc/gpu/sample_kernels/utils.cuh | 19 +++++- .../transformers/generation_utils.py | 14 ++-- 4 files changed, 71 insertions(+), 46 deletions(-) diff --git a/csrc/gpu/sample_kernels/sampling.cuh b/csrc/gpu/sample_kernels/sampling.cuh index 64049b57d96f..5c420f5dfb9f 100644 --- a/csrc/gpu/sample_kernels/sampling.cuh +++ b/csrc/gpu/sample_kernels/sampling.cuh @@ -1,3 +1,20 @@ +// Copyright © 2024 PaddlePaddle Name. All Rights Reserved. +// +// This code is partially inspired by and references the implementation found in FlashInfer. +// Specifically, the implementation of Top-p Sampling functionality in this code is inspired by the logic of FlashInfer’s flashinfer.sampling.top_p_sampling_from_probs function. +// For more details on FlashInfer’s documentation, please refer to: https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_sampling_from_probs.html#flashinfer-sampling-top-p-sampling-from_probs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include @@ -5,7 +22,7 @@ #include #include -#include "utils.cuh" +#include "sample_kernels/utils.cuh" namespace sampling { diff --git a/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu b/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu index 121a9ff40603..1e98b0a81cd5 100644 --- a/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu +++ b/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu @@ -13,26 +13,25 @@ // limitations under the License. 
#include "helper.h" -#include "sampling.cuh" +#include "sample_kernels/sampling.cuh" -std::vector top_p_sampling_from_probs( - const paddle::Tensor& probs, - const paddle::Tensor& uniform_samples, - const paddle::Tensor& top_p) { +std::vector TopPSamplingReject(const paddle::Tensor& probs, + const paddle::Tensor& top_p) { std::vector probs_shape = probs.shape(); unsigned int batch_size = probs_shape[0]; unsigned int vocab_size = probs_shape[1]; - std::vector uniform_samples_shape = uniform_samples.shape(); - PD_CHECK(uniform_samples_shape[0], batch_size); - unsigned int max_top_p_rounds = uniform_samples_shape[1]; + + // default is 32 + unsigned int max_top_p_rounds = 32; + std::vector uniform_samples_shape = {batch_size, max_top_p_rounds}; + paddle::Tensor uniform_samples = paddle::experimental::uniform( + uniform_samples_shape, paddle::DataType::FLOAT32, 0, 1, 0, probs.place()); + // todo: add parameter for deterministic, now default is true bool deterministic = true; paddle::Tensor probs_input; - paddle::Tensor uniform_samples_input; probs_input = paddle::experimental::cast(probs, paddle::DataType::FLOAT32); - uniform_samples_input = - paddle::experimental::cast(uniform_samples, paddle::DataType::FLOAT32); auto cu_stream = probs.stream(); auto samples = @@ -40,44 +39,42 @@ std::vector top_p_sampling_from_probs( auto success = paddle::full({batch_size}, 0, paddle::DataType::BOOL, probs.place()); - cudaError_t status = sampling::TopPSamplingFromProb( - probs_input.data(), - uniform_samples_input.data(), - samples.data(), - success.data(), - nullptr, - batch_size, - top_p.data(), - vocab_size, - max_top_p_rounds, - deterministic, - cu_stream); + cudaError_t status = + sampling::TopPSamplingFromProb(probs_input.data(), + uniform_samples.data(), + samples.data(), + success.data(), + nullptr, + batch_size, + top_p.data(), + vocab_size, + max_top_p_rounds, + deterministic, + cu_stream); PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); + paddle::Tensor samples_output; samples_output = paddle::experimental::cast(samples, paddle::DataType::INT64); return {samples_output}; } -std::vector> top_p_sampling_from_probs_InferShape( +std::vector> TopPSamplingRejectInferShape( const std::vector& probs_shape, - const std::vector& uniform_samples_shape, const std::vector& top_p_shape) { int64_t bs = probs_shape[0]; return {{bs, 1}}; } -std::vector top_p_sampling_from_probs_InferDtype( - const paddle::DataType& probs_dtype, - const paddle::DataType& uniform_samples_dtype, - const paddle::DataType& top_p_shape) { +std::vector TopPSamplingRejectInferDtype( + const paddle::DataType& probs_dtype, const paddle::DataType& top_p_shape) { return {probs_dtype}; } -PD_BUILD_OP(top_p_sampling_from_probs) - .Inputs({"probs", "uniform_samples", "top_p"}) +PD_BUILD_OP(top_p_sampling_reject) + .Inputs({"probs", "top_p"}) .Outputs({"samples"}) - .SetKernelFn(PD_KERNEL(top_p_sampling_from_probs)) - .SetInferShapeFn(PD_INFER_SHAPE(top_p_sampling_from_probs_InferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(top_p_sampling_from_probs_InferDtype)); + .SetKernelFn(PD_KERNEL(TopPSamplingReject)) + .SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingRejectInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype)); diff --git a/csrc/gpu/sample_kernels/utils.cuh b/csrc/gpu/sample_kernels/utils.cuh index b2ab5416f4b3..39b93cf161b9 100644 --- a/csrc/gpu/sample_kernels/utils.cuh +++ b/csrc/gpu/sample_kernels/utils.cuh @@ -1,3 +1,20 @@ +// Copyright © 2024 
PaddlePaddle Name. All Rights Reserved. +// +// This code is partially inspired by and references the implementation found in FlashInfer. +// Specifically, the implementation of Top-p Sampling functionality in this code is inspired by the logic of FlashInfer’s flashinfer.sampling.top_p_sampling_from_probs function. +// For more details on FlashInfer’s documentation, please refer to: https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_sampling_from_probs.html#flashinfer-sampling-top-p-sampling-from_probs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include @@ -246,4 +263,4 @@ __forceinline__ __device__ float ptx_rcp(float x) { template __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { return (x + y - 1) / y; -} \ No newline at end of file +} diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py index 29409c1b2f39..5133c944cdf6 100644 --- a/paddlenlp/experimental/transformers/generation_utils.py +++ b/paddlenlp/experimental/transformers/generation_utils.py @@ -331,12 +331,9 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): # compute next_tokens try: - from paddlenlp_ops import top_p_sampling_from_probs + from paddlenlp_ops import top_p_sampling_reject - # max_rounds default is 32 - bs = probs.shape[0] - uniform_samples = paddle.randn([bs, 32]) - next_tokens = top_p_sampling_from_probs(probs, uniform_samples, top_p) + next_tokens = top_p_sampling_reject(probs, top_p) except: _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p) @@ -678,12 +675,9 @@ def _post_process_( # compute next_tokens try: - from paddlenlp_ops import top_p_sampling_from_probs + from paddlenlp_ops import top_p_sampling_reject - bs = probs.shape[0] - # max_rounds default is 32 - uniform_samples = paddle.randn([bs, 32]) - next_tokens = top_p_sampling_from_probs(probs, uniform_samples, top_p) + next_tokens = top_p_sampling_reject(probs, top_p) except: _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p) From 934f536d720ea04e103f6e8ab5d714cb2a39dd7d Mon Sep 17 00:00:00 2001 From: gzy19990617 Date: Sat, 14 Sep 2024 10:29:20 +0000 Subject: [PATCH 4/5] fix --- .../{top_p_sampling_from_probs.cu => top_p_sampling_reject.cu} | 0 csrc/setup_cuda.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename csrc/gpu/sample_kernels/{top_p_sampling_from_probs.cu => top_p_sampling_reject.cu} (100%) diff --git a/csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu similarity index 100% rename from csrc/gpu/sample_kernels/top_p_sampling_from_probs.cu rename to csrc/gpu/sample_kernels/top_p_sampling_reject.cu diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index 269a37acded3..a2e75fb465f5 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -107,7 +107,7 @@ def get_gencode_flags(): "./gpu/dequant_int8.cu", "./gpu/flash_attn_bwd.cc", "./gpu/tune_cublaslt_gemm.cu", - 
"./gpu/sample_kernels/top_p_sampling_from_probs.cu", + "./gpu/sample_kernels/top_p_sampling_reject", ] cutlass_dir = "third_party/cutlass" From 2633f4e27cc0779ecdc493674a3dcf72d59dbcfa Mon Sep 17 00:00:00 2001 From: gzy19990617 Date: Wed, 18 Sep 2024 06:32:38 +0000 Subject: [PATCH 5/5] fix review --- csrc/gpu/sample_kernels/sampling.cuh | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/csrc/gpu/sample_kernels/sampling.cuh b/csrc/gpu/sample_kernels/sampling.cuh index 5c420f5dfb9f..334747dd0126 100644 --- a/csrc/gpu/sample_kernels/sampling.cuh +++ b/csrc/gpu/sample_kernels/sampling.cuh @@ -1,8 +1,4 @@ -// Copyright © 2024 PaddlePaddle Name. All Rights Reserved. -// -// This code is partially inspired by and references the implementation found in FlashInfer. -// Specifically, the implementation of Top-p Sampling functionality in this code is inspired by the logic of FlashInfer’s flashinfer.sampling.top_p_sampling_from_probs function. -// For more details on FlashInfer’s documentation, please refer to: https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_sampling_from_probs.html#flashinfer-sampling-top-p-sampling-from_probs +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +11,14 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +// This code is partially inspired by and references the implementation found +// in FlashInfer.Specifically, the implementation of Top-p Sampling functionality +// in this code is inspired by the logic of +// FlashInfer’s flashinfer.sampling.top_p_sampling_from_probs . +// For more details on FlashInfer’s documentation, please refer to: +// https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_sampling_from_probs.html + #pragma once #include