diff --git a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu b/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu index a9bf01d56ac4c6..0cd34f5df8b7dc 100644 --- a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +++ b/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu @@ -28,6 +28,8 @@ at::Tensor ms_deform_attn_cuda_forward( const at::Tensor &attn_weight, const int im2col_step) { + at::DeviceGuard guard(value.device()); + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); @@ -92,6 +94,7 @@ std::vector ms_deform_attn_cuda_backward( const at::Tensor &grad_output, const int im2col_step) { + at::DeviceGuard guard(value.device()); AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");