From eeb92d2f4c34f143fc76e409987543d42e68d065 Mon Sep 17 00:00:00 2001 From: Stonepia Date: Fri, 26 Jul 2024 17:36:33 +0800 Subject: [PATCH] Use data_ptr template instead of force data conversion (#4558) --- csrc/gpu/aten/operators/ROIAlign.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/gpu/aten/operators/ROIAlign.cpp b/csrc/gpu/aten/operators/ROIAlign.cpp index b8b459017..9a08ce414 100644 --- a/csrc/gpu/aten/operators/ROIAlign.cpp +++ b/csrc/gpu/aten/operators/ROIAlign.cpp @@ -442,9 +442,9 @@ at::Tensor roi_align_forward_kernel( [&] { auto spatial_scale_ = static_cast(spatial_scale); auto cgf = DPCPP_Q_CGF(cgh) { - auto input_ptr = (scalar_t*)input_.data_ptr(); - auto rois_ptr = (scalar_t*)rois_.data_ptr(); - auto output_ptr = (scalar_t*)output.data_ptr(); + auto input_ptr = input_.data_ptr(); + auto rois_ptr = rois_.data_ptr(); + auto output_ptr = output.data_ptr(); ROIAlignForwardKernelFunctor kfn( output_size, input_ptr, @@ -592,9 +592,9 @@ at::Tensor roi_align_backward_kernel( [&] { auto spatial_scale_ = static_cast(spatial_scale); auto cgf = DPCPP_Q_CGF(cgh) { - auto grad_ptr = (scalar_t*)grad.data_ptr(); - auto grad_input_ptr = (scalar_t*)grad_input.data_ptr(); - auto rois_ptr = (scalar_t*)rois_.data_ptr(); + auto grad_ptr = grad.data_ptr(); + auto grad_input_ptr = grad_input.data_ptr(); + auto rois_ptr = rois_.data_ptr(); auto grad_numel = grad.numel(); ROIAlignBackwardKernelFunctor kfn( grad_numel,