Skip to content

Commit

Permalink
Use data_ptr template instead of force data conversion (#4558)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stonepia authored Jul 26, 2024
1 parent e4938e0 commit eeb92d2
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions csrc/gpu/aten/operators/ROIAlign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,9 @@ at::Tensor roi_align_forward_kernel(
[&] {
auto spatial_scale_ = static_cast<scalar_t>(spatial_scale);
auto cgf = DPCPP_Q_CGF(cgh) {
auto input_ptr = (scalar_t*)input_.data_ptr();
auto rois_ptr = (scalar_t*)rois_.data_ptr();
auto output_ptr = (scalar_t*)output.data_ptr();
auto input_ptr = input_.data_ptr<scalar_t>();
auto rois_ptr = rois_.data_ptr<scalar_t>();
auto output_ptr = output.data_ptr<scalar_t>();
ROIAlignForwardKernelFunctor<scalar_t> kfn(
output_size,
input_ptr,
Expand Down Expand Up @@ -592,9 +592,9 @@ at::Tensor roi_align_backward_kernel(
[&] {
auto spatial_scale_ = static_cast<scalar_t>(spatial_scale);
auto cgf = DPCPP_Q_CGF(cgh) {
auto grad_ptr = (scalar_t*)grad.data_ptr();
auto grad_input_ptr = (scalar_t*)grad_input.data_ptr();
auto rois_ptr = (scalar_t*)rois_.data_ptr();
auto grad_ptr = grad.data_ptr<scalar_t>();
auto grad_input_ptr = grad_input.data_ptr<scalar_t>();
auto rois_ptr = rois_.data_ptr<scalar_t>();
auto grad_numel = grad.numel();
ROIAlignBackwardKernelFunctor<scalar_t> kfn(
grad_numel,
Expand Down

0 comments on commit eeb92d2

Please sign in to comment.