Sdpa dino v2 #33403

Merged
merged 10 commits into huggingface:main from avishaiElmakies:sdpa_dinoV2 on Sep 21, 2024

Conversation

@avishaiElmakies (Contributor)

Adds SDPA to Dinov2.

The implementation is copied from ViT.

Addresses #28005.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts @fxmarty

Notes

  • Some tests that set output_attentions=True seem to fail, but SDPA doesn't support returning attention weights. When I added SDPA to OPT, the layer fell back to eager mode when output_attentions == True, but ViT didn't have that fallback, so I didn't add it here. I can add it if needed. (A sketch of the ported attention follows below.)
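
For reference, here is roughly what the ported ViT-style SDPA attention looks like. This is a simplified sketch rather than the exact diff from this PR; it assumes the existing Dinov2SelfAttention from modeling_dinov2.py, whose projections and helpers it reuses.

```python
import torch

from transformers.models.dinov2.modeling_dinov2 import Dinov2SelfAttention


class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
    """Variant of Dinov2SelfAttention that routes through torch's fused SDPA kernel."""

    def __init__(self, config):
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        # Same projections and (batch, heads, seq_len, head_dim) reshape as the eager path.
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Fused attention kernel; note it cannot return the attention probabilities.
        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=head_mask,
            dropout_p=self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
        )

        # Merge the heads back into (batch, seq_len, hidden_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(context_layer.size()[:-2] + (self.all_head_size,))

        # None stands in for the attention weights, which SDPA does not expose.
        return context_layer, None
```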

@amyeroberts (Collaborator) left a comment

Thanks for adding!

Before merge we'll need to run the slow model tests. Could you push an empty commit with the message [run-slow] dinov2?

The currently failing tests will also need to be resolved.

Review thread on docs/source/en/perf_infer_gpu_one.md (outdated, resolved)
@avishaiElmakies (Contributor, Author)

@amyeroberts how should I resolve the failing tests?

@amyeroberts (Collaborator)

@avishaiElmakies You'll need to look at the CI logs and debug the problem based on the error messages. The first step is to make sure you can replicate the errors locally. You can also inspect other PRs which added SDPA, e.g. #30555, to see if there's a common pattern of fixes that need to be applied.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@avishaiElmakies (Contributor, Author) commented Sep 10, 2024

@amyeroberts thanks for the reply. I would love some guidance if possible.

From what I understand, these are the failing tests:

FAILED tests/models/dinov2/test_modeling_dinov2.py::Dinov2ModelTest::test_attention_outputs - AttributeError: 'NoneType' object has no attribute 'shape'
FAILED tests/models/dinov2/test_modeling_dinov2.py::Dinov2ModelTest::test_retain_grad_hidden_states_attentions - AttributeError: 'NoneType' object has no attribute 'retain_grad'
FAILED tests/models/dinov2/test_modeling_dinov2.py::Dinov2ModelTest::test_equivalence_flax_to_pt - AssertionError: False is not true : outputs.attentions_0: `pt_outputs` should a tensor when `fx_outputs` is
FAILED tests/models/dinov2/test_modeling_dinov2.py::Dinov2ModelTest::test_equivalence_pt_to_flax - AssertionError: False is not true : outputs.attentions_0: `pt_outputs` should a tensor when `fx_outputs` is

From what I understand, these failures are related to the fact that SDPA doesn't return output_attentions, so I would love to know how to handle them. If output_attentions == True, should I fall back to eager?

FAILED tests/models/dinov2/test_modeling_dinov2.py::Dinov2ModelTest::test_eager_matches_sdpa_inference_0_float16 - AttributeError: 'Dinov2ModelTester' object has no attribute 'num_masks'
FAILED tests/models/dinov2/test_modeling_dinov2.py::Dinov2ModelTest::test_eager_matches_sdpa_inference_1_bfloat16 - AttributeError: 'Dinov2ModelTester' object has no attribute 'num_masks'
FAILED tests/models/dinov2/test_modeling_dinov2.py::Dinov2ModelTest::test_eager_matches_sdpa_inference_2_float32 - AttributeError: 'Dinov2ModelTester' object has no attribute 'num_masks'
FAILED tests/models/dinov2/test_modeling_dinov2.py::Dinov2ModelIntegrationTest::test_inference_no_head - AssertionError: False is not true

These are very weird to me. I can replicate them, but something is off: it seems someone changed the Dinov2Embeddings class, and it differs from the main branch. When I use the Dinov2Embeddings from the main branch, the tests pass. I can't find the culprit or the relevant pull request, so I would love some guidance here as well.

amyeroberts mentioned this pull request on Sep 10, 2024
@avishaiElmakies (Contributor, Author)

@amyeroberts I would love some guidance here.

About the first four tests: I couldn't find any PR that discusses this situation. What I have seen is that models (e.g. Llama) fall back to the eager implementation if output_attentions == True.

I would also love some guidance about the other four tests.

@amyeroberts (Collaborator)

For the first four tests - refer to the linked PR and the adaptations to the tests made there. You'll see that "eager" is indeed selected. The diffs will show how this should be done.
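
Concretely, the adaptation looks something like the following at the top of the SDPA attention's forward. This is a sketch of the common fallback pattern used in other SDPA ports, not the exact diff from this PR:

```python
from transformers.models.dinov2.modeling_dinov2 import Dinov2SelfAttention
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        if output_attentions:
            # SDPA never materializes the attention weights, so honor
            # output_attentions=True by falling back to the eager parent class.
            logger.warning_once(
                "`output_attentions=True` is not supported by the SDPA attention; "
                "falling back to the eager implementation."
            )
            return super().forward(hidden_states, head_mask, output_attentions)
        # ...otherwise run the fused SDPA path as in the earlier sketch...
```

The tests then request the "eager" implementation whenever they need the attention weights.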

For the other tests -- re the comment "It seems someone changed the Dinov2Embeddings class, and it is different from the main branch.": the version of DinoV2 on this branch depends on which commit from main you branched off from. You can see all the changes to the dino v2 model here: https://github.com/huggingface/transformers/commits/ce62a41880b5b70a304d068eb58f55894a5a7af8/src/transformers/models/dinov2/modeling_dinov2.py.

Any PR should be in sync and up-to-date with the main branch. To make sure this is the case, you'll need to do frequent merges from main into this branch or rebases of this branch onto main.

If the branches are up-to-date and there still remains a difference between tests passing on main and this branch, then the changes in this PR are causing the tests to fail and you'll need to debug to investigate. Here's a handy guide on debugging with pytest: https://docs.pytest.org/en/stable/how-to/failures.html

@avishaiElmakies (Contributor, Author) commented Sep 18, 2024

@amyeroberts I fixed the 7 tests (at least on my machine). test_inference_no_head fails on the main branch as well.

@amyeroberts (Collaborator)

@avishaiElmakies Thanks for running the tests and adding this feature! I can confirm tests/models/dinov2/test_modeling_dinov2.py::Dinov2ModelIntegrationTest::test_inference_no_head fails on main, so this is good to merge!

amyeroberts merged commit 78b2929 into huggingface:main on Sep 21, 2024 (15 of 17 checks passed)
@lezhang7

@avishaiElmakies Thanks for your great contribution, this new feature helps us a lot! BTW, could you share an example or update the docs on how to use it?

@avishaiElmakies (Contributor, Author)

@amyeroberts great to hear, thanks for the guidance!

@lezhang7 I don't understand, this is SDPA in Dinov2

@NielsRogge (Contributor)

avishaiElmakies deleted the sdpa_dinoV2 branch on September 21, 2024 at 08:39
@lezhang7

Thank you @NielsRogge! I use the following to load the model:

from transformers import AutoImageProcessor, AutoModel

processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')
model = AutoModel.from_pretrained('facebook/dinov2-large', attn_implementation="sdpa")

However, I didn't see any inference speedup compared to the original code without the explicit sdpa flag. Is there anything wrong with my code?

@avishaiElmakies (Contributor, Author)

I believe that after adding the SDPA implementation, it becomes the default (at least for Dinov2), so you don't need the flag to get SDPA. You should compare

attn_implementation="eager" vs attn_implementation="sdpa"
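
For example, a minimal timing sketch along those lines (CPU-only; the batch shape and loop count are arbitrary choices for illustration):

```python
import time

import torch
from transformers import AutoModel

# Dummy batch of 224x224 images (the size the image processor produces by default).
pixel_values = torch.randn(4, 3, 224, 224)

for impl in ("eager", "sdpa"):
    model = AutoModel.from_pretrained("facebook/dinov2-large", attn_implementation=impl)
    model.eval()
    with torch.no_grad():
        model(pixel_values=pixel_values)  # warmup run
        start = time.perf_counter()
        for _ in range(5):
            model(pixel_values=pixel_values)
    print(f"{impl}: {(time.perf_counter() - start) / 5:.3f}s per forward pass")
```

On a GPU you would additionally move the model and inputs to the device and call torch.cuda.synchronize() before reading the timers; the SDPA gains are usually larger there.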

amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Oct 2, 2024
* add sdpa to dinov2

* fixup

* add dinov2 to sdpa doc

* update doc order

* [run-slow] dinov2

* common to eager

* [run-slow] dinov2

* update attn implementation in common

* update test_modeling_dinov2 to have mask_ration, num_masks and mask_length similar to vit

* [run-slow] dinov2

---------

Co-authored-by: Avishai Elmakies <[email protected]>