Skip to content

Commit

Permalink
[Paddle] Update type names for Paddle 3.0 (#1286)
Browse files Browse the repository at this point in the history
Update class names for Paddle 3.0

Signed-off-by: Tim Moon <[email protected]>
  • Loading branch information
timmoon10 authored Oct 24, 2024
1 parent 18c2234 commit 7a5fd0c
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions transformer_engine/paddle/csrc/custom_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -603,14 +603,14 @@ void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_p
auto state_index = gen_cuda->GetStateIndex();

auto parameterSetter = [gen_cuda, state_index,
rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams &params) {
rng_elts_per_thread](phi::backends::gpu::gpuKernelParams &params) {
// ensure the generator uses the correct state index
gen_cuda->SetStateIndex(state_index);
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
params.As<std::pair<int64_t, int64_t>>(1) = seed_offset;
};

phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback =
phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback =
[=](unsigned int id) {
void *functionPtr = reinterpret_cast<void *>(&set_rng_state);
cudaFunction_t cudaFunc;
Expand Down Expand Up @@ -1016,14 +1016,14 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
#if PADDLE_VERSION > 261
auto state_index = gen_cuda->GetStateIndex();
auto parameterSetter = [gen_cuda, state_index,
rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams &params) {
rng_elts_per_thread](phi::backends::gpu::gpuKernelParams &params) {
// ensure the generator uses the correct state index
gen_cuda->SetStateIndex(state_index);
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
params.As<std::pair<int64_t, int64_t>>(1) = seed_offset;
};

phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback =
phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback =
[=](unsigned int id) {
void *functionPtr = reinterpret_cast<void *>(&set_rng_state);
cudaFunction_t cudaFunc;
Expand Down Expand Up @@ -1383,7 +1383,7 @@ void amax_and_scale_update_inplace_legacy(
const int *current_step_id_ptr =
reinterpret_cast<const int *>(GetOptionalDataPtr(current_step_id_tensor));
auto parameterSetter = [current_step_id_ptr,
fwd_update](phi::backends::gpu::CUDAKernelParams &params) {
fwd_update](phi::backends::gpu::gpuKernelParams &params) {
if (fwd_update) {
int current_step_id = *current_step_id_ptr;
params.As<bool>(7) = (current_step_id == 0);
Expand All @@ -1397,7 +1397,7 @@ void amax_and_scale_update_inplace_legacy(
float *scale_ptr = scale.data<float>();
float *scale_inv_ptr = scale_inv.data<float>();

phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback =
phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback =
[=](unsigned int id) {
void *functionPtr = reinterpret_cast<void *>(&UpdateFP8MetaKernel);
cudaFunction_t cudaFunc;
Expand Down

0 comments on commit 7a5fd0c

Please sign in to comment.