dpt model tests regress after PR #33403 #33649

Closed
dvrogozh opened this issue Sep 22, 2024 · 8 comments · Fixed by #33660

@dvrogozh (Contributor)

Follow-up from #33485 (comment). On commit 78b2929, there is a regression after merging this PR in the following 2 tests:

  • tests/models/dpt/test_modeling_dpt_auto_backbone.py::DPTModelTest::test_attention_outputs
  • tests/models/dpt/test_modeling_dpt_auto_backbone.py::DPTModelTest::test_retain_grad_hidden_states_attentions

CC: @avishaiElmakies, @amyeroberts

Example output:

$ python3 -m pytest tests/models/dpt/test_modeling_dpt_auto_backbone.py::DPTModelTest::test_retain_grad_hidden_states_attentions
=================================================== test session starts ====================================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0
rootdir: /home/dvrogozh/git/huggingface/transformers
configfile: pyproject.toml
plugins: pspec-0.0.4, timeout-2.3.1, hypothesis-6.112.1, xdist-3.6.1, rich-0.1.1, dash-2.18.1, cov-5.0.0, typeguard-4.3.0
collected 1 item

tests/models/dpt/test_modeling_dpt_auto_backbone.py F                                                                [100%]

========================================================= FAILURES =========================================================
__________________________________ DPTModelTest.test_retain_grad_hidden_states_attentions __________________________________

self = <tests.models.dpt.test_modeling_dpt_auto_backbone.DPTModelTest testMethod=test_retain_grad_hidden_states_attentions>

    def test_retain_grad_hidden_states_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = self.has_attentions

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        inputs = self._prepare_for_class(inputs_dict, model_class)

        outputs = model(**inputs)

        output = outputs[0]

        if config.is_encoder_decoder:
            # Seq2Seq models
            encoder_hidden_states = outputs.encoder_hidden_states[0]
            encoder_hidden_states.retain_grad()

            decoder_hidden_states = outputs.decoder_hidden_states[0]
            decoder_hidden_states.retain_grad()

            if self.has_attentions:
                encoder_attentions = outputs.encoder_attentions[0]
                encoder_attentions.retain_grad()

                decoder_attentions = outputs.decoder_attentions[0]
                decoder_attentions.retain_grad()

                cross_attentions = outputs.cross_attentions[0]
                cross_attentions.retain_grad()

            output.flatten()[0].backward(retain_graph=True)

            self.assertIsNotNone(encoder_hidden_states.grad)
            self.assertIsNotNone(decoder_hidden_states.grad)

            if self.has_attentions:
                self.assertIsNotNone(encoder_attentions.grad)
                self.assertIsNotNone(decoder_attentions.grad)
                self.assertIsNotNone(cross_attentions.grad)
        else:
            # Encoder-/Decoder-only models
            hidden_states = outputs.hidden_states[0]
            hidden_states.retain_grad()

            if self.has_attentions:
                attentions = outputs.attentions[0]
>               attentions.retain_grad()
E               AttributeError: 'NoneType' object has no attribute 'retain_grad'

tests/test_modeling_common.py:1677: AttributeError
===================================================== warnings summary =====================================================
src/transformers/deepspeed.py:24
  /home/dvrogozh/git/huggingface/transformers/src/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================= short test summary info ==================================================
FAILED tests/models/dpt/test_modeling_dpt_auto_backbone.py::DPTModelTest::test_retain_grad_hidden_states_attentions - AttributeError: 'NoneType' object has no attribute 'retain_grad'
=============================================== 1 failed, 1 warning in 2.21s ===============================================

Bisect:

$ git bisect log
git bisect start
# bad: [78b2929c0554b79e0489b451ce4ece14d265ead2] Sdpa dino v2 (#33403)
git bisect bad 78b2929c0554b79e0489b451ce4ece14d265ead2
# good: [b50ff5993a5d8b2a3d8c7558e81684f8803b044a] [`Mamba2`] Move dt calculations to kernel (#33520)
git bisect good b50ff5993a5d8b2a3d8c7558e81684f8803b044a
# good: [653eb40425344b89b5a24e7b07eb3095b04cdc9d] Add sdpa for BioGpt (#33592)
git bisect good 653eb40425344b89b5a24e7b07eb3095b04cdc9d
# good: [077b552f0780c678737700184c109066736ece41] Fix some missing tests in circleci (#33559)
git bisect good 077b552f0780c678737700184c109066736ece41
# good: [7b2b536a811c84831e2c67eb388872b7c83a8263] Fix typos (#33583)
git bisect good 7b2b536a811c84831e2c67eb388872b7c83a8263
# good: [e472e077c24d6f6f080f5535f01c48f09164ec62] Granitemoe (#33207)
git bisect good e472e077c24d6f6f080f5535f01c48f09164ec62
# good: [e71bf70e33d501810951f353f1734cb5be74b32a] Pixtral update example checkpoint (#33633)
git bisect good e71bf70e33d501810951f353f1734cb5be74b32a
# first bad commit: [78b2929c0554b79e0489b451ce4ece14d265ead2] Sdpa dino v2 (#33403)
@vasqu (Contributor) commented Sep 22, 2024

Both tests are failing because SDPA does not support returning the attention weights. As far as I recall, text-based models fall back to eager attention (with a warning) when attention outputs are requested (or when other unsupported operations are needed).

This can be fixed in two ways:

  • Fall back to eager, as is often done in text models; llama ref:

        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

  • Adjust the tests to fall back to the eager implementation, e.g. by overwriting them.

Depends on preferences, I guess, although in the future the base test should be adjusted instead, imo. A rough sketch of the first option is below.
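Here is a minimal, self-contained sketch of that fallback pattern (illustrative only; EagerSelfAttention and SdpaSelfAttention are made-up names, not the actual transformers classes):

    import logging
    from typing import Optional, Tuple

    import torch
    from torch import nn

    logger = logging.getLogger(__name__)


    class EagerSelfAttention(nn.Module):
        """Toy eager attention that materializes the attention weights."""

        def __init__(self, hidden_size: int, num_heads: int):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = hidden_size // num_heads
            self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
            self.out = nn.Linear(hidden_size, hidden_size)

        def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
            batch, seq, _ = x.shape
            return x.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

        def forward(
            self, hidden_states: torch.Tensor, output_attentions: bool = False
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            q, k, v = (self._split_heads(t) for t in self.qkv(hidden_states).chunk(3, dim=-1))
            attn = torch.softmax(q @ k.transpose(-1, -2) / self.head_dim**0.5, dim=-1)
            context = (attn @ v).transpose(1, 2).flatten(2)
            return self.out(context), attn if output_attentions else None


    class SdpaSelfAttention(EagerSelfAttention):
        """SDPA variant: it cannot return attention weights, so it defers to eager when asked for them."""

        def forward(
            self, hidden_states: torch.Tensor, output_attentions: bool = False
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            if output_attentions:
                # SDPA never exposes the attention weights, so fall back to the eager path
                logger.warning(
                    "SDPA does not support output_attentions=True; falling back to the eager implementation."
                )
                return super().forward(hidden_states, output_attentions=True)
            q, k, v = (self._split_heads(t) for t in self.qkv(hidden_states).chunk(3, dim=-1))
            context = nn.functional.scaled_dot_product_attention(q, k, v)
            return self.out(context.transpose(1, 2).flatten(2)), None

The point is that SDPA stays the default path, but output_attentions=True still works, which is exactly what the failing tests exercise.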

@avishaiElmakies (Contributor)

@dvrogozh why is dpt affected by my changes to dinov2?

@avishaiElmakies (Contributor) commented Sep 22, 2024

@vasqu There is no SDPA implementation for DPT, as far as I understand.

@dvrogozh The tests also seem to pass on my machine on the latest main.

@vasqu (Contributor) commented Sep 22, 2024

@avishaiElmakies It's a special case where dpt uses dinov2 as a backbone; see, for example:

    def get_backbone_config(self):
        return Dinov2Config(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            is_training=self.is_training,
            out_features=self.out_features,
            reshape_hidden_states=self.reshape_hidden_states,
        )

Hence the hidden failures, as the dependency wasn't obvious.
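For illustration, the coupling can be seen with just the configs (a sketch; the toy sizes are arbitrary, and only the fact that a DPTConfig wraps a Dinov2Config matters here):

    from transformers import Dinov2Config, DPTConfig

    # DPT with an "auto backbone" wraps a Dinov2Config, so the DPT tests exercise
    # whatever attention implementation Dinov2 picks by default.
    backbone_config = Dinov2Config(
        image_size=32,
        patch_size=16,
        num_channels=3,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        out_features=["stage1", "stage2"],
    )
    config = DPTConfig(backbone_config=backbone_config)
    print(type(config.backbone_config).__name__)  # Dinov2Config

So a change to Dinov2 attention (like #33403) can surface in the DPT test suite even though no DPT code changed.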

@avishaiElmakies (Contributor)

@vasqu OK, I didn't know that.

Should I fix that?

@vasqu (Contributor) commented Sep 22, 2024

If you have time, gladly :) @avishaiElmakies

@avishaiElmakies (Contributor)

I feel like it is kinda my fault, so I will try to handle it @vasqu
