Skip to content

Commit

Permalink
DeviceGuard added to use Deformable Attention more safely on multi-GPU (
Browse files Browse the repository at this point in the history
huggingface#32910)

* Update modeling_deformable_detr.py

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update ms_deform_attn_cuda.cu

* Update modeling_deformable_detr.py

* Update modeling_deformable_detr.py

* [empty] this is a empty commit

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
2 people authored and zucchini-nlp committed Aug 30, 2024
1 parent 773862a commit 91a8183
Showing 1 changed file with 3 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &attn_weight,
const int im2col_step)
{
at::DeviceGuard guard(value.device());

AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
Expand Down Expand Up @@ -92,6 +94,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &grad_output,
const int im2col_step)
{
at::DeviceGuard guard(value.device());

AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
Expand Down

0 comments on commit 91a8183

Please sign in to comment.