
Add Flax Dinov2 #31960

Merged 52 commits into huggingface:main on Aug 19, 2024

Conversation

@MHRDYN7
Contributor

MHRDYN7 commented Jul 14, 2024

This PR adds the Flax implementation of Dinov2, which appears to have been pending since #25579.

All the components of the PyTorch Dinov2 model can be converted to Flax except interpolate_pos_encoding, which uses torch.nn.functional.interpolate. The closest JAX function for replicating this is jax.image.scale_and_translate; however, there seems to be a slight difference between these functions in "bicubic" mode (https://github.com/google/jax/issues/15768).
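For illustration, here is a minimal sketch of what the interpolation can look like with jax.image.scale_and_translate (the shapes, variable names, and NHWC layout below are assumptions for the example, not the exact code in this PR):

import jax
import jax.numpy as jnp

# Example shapes: position embeddings pretrained on a 37x37 patch grid
# (image_size=518, patch_size=14), interpolated to a 16x16 grid for 224x224 inputs.
dim = 768
src_grid, tgt_h, tgt_w = 37, 16, 16
patch_pos_embed = jnp.zeros((1, src_grid * src_grid, dim))  # dummy pretrained table

# reshape the flat position table into a 2D grid so it can be resized spatially
grid = patch_pos_embed.reshape(1, src_grid, src_grid, dim)
scale = jnp.array([tgt_h / src_grid, tgt_w / src_grid], dtype=jnp.float32)

resized = jax.image.scale_and_translate(
    grid,
    shape=(1, tgt_h, tgt_w, dim),
    spatial_dims=(1, 2),
    scale=scale,
    translation=jnp.zeros(2, dtype=jnp.float32),
    method="bicubic",  # the mode where JAX and torch.nn.functional.interpolate differ slightly
    antialias=False,
)
interpolated_pos_embed = resized.reshape(1, tgt_h * tgt_w, dim)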

In Dinov2, the pretrained weights of the position encoding correspond to an image size of 518, but we load images of size 224 into the model. The interpolate function reshapes the position encoding according to the size of the input images. The PyTorch ViT model also has this interpolate function, but it is absent from the FlaxViT implementation because there the config and input image sizes are the same.

For now, I have directly loaded the pos_encoding weights from the PyTorch model into Flax right after interpolation (saved in a safetensors file). This passes all the tests (including the two new integration tests added on top of the FlaxViT tests). Obviously, this brute-force approach of loading the original interpolated pos_encodings won't be acceptable in the long run, but without it the slight deviations from jax.image.scale_and_translate make the tests fail. @amyeroberts @sanchit-gandhi

Other remaining tasks:

  1. Add the Flax weights (.msgpack files) to the Hub
  2. Test the SwiGLUFFN dense layer for ViT giant (a rough sketch of such a layer follows below)
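For reference, a rough Flax sketch of how a SwiGLU feed-forward block can be written (the module name, sizes, and hidden-dimension choice are assumptions for illustration, not necessarily the module shipped in this PR):

import jax
import jax.numpy as jnp
import flax.linen as nn


class SwiGLUFFNSketch(nn.Module):
    # Illustrative SwiGLU feed-forward block: a single projection producing gate and value
    # halves, SiLU gating, then a projection back to the hidden size.
    hidden_size: int = 1536        # e.g. a ViT-giant-like hidden size
    intermediate_size: int = 4096  # example value; the real config derives this differently

    @nn.compact
    def __call__(self, hidden_states):
        x = nn.Dense(2 * self.intermediate_size)(hidden_states)
        gate, value = jnp.split(x, 2, axis=-1)
        hidden = jax.nn.silu(gate) * value
        return nn.Dense(self.hidden_size)(hidden)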

@amyeroberts
Collaborator

Hi @MHRDYN7, thanks for working on this conversion!

The interpolate logic is tricky. If loading the weights directly means the model passes for now, that's a good indication the conversion is OK. We might have to remove this before merge and skip the equivalence tests.

In the meantime, the first thing to do is get the other tests passing. Some of the failing tests are unrelated and have fixes upstream. Could you rebase on main to include these? For the quality checks, running make fixup should resolve them.

@MHRDYN7
Contributor Author

MHRDYN7 commented Jul 24, 2024

Hi @amyeroberts, I did try make fixup, and I'm not really sure why the two tests are still failing. Moreover, what should be done for skipping the equivalence test? Should I just remove the directly loaded tensors after interpolation and change the "expected_slice" tensor in the integration tests accordingly to make them pass?

@amyeroberts
Collaborator

amyeroberts commented Jul 26, 2024

@MHRDYN7 For the quality checks, you'll need to run make fix-copies and then possibly make fixup afterwards (I would, just to be safe). You can see this from the CI logs.

Moreover, what should be done for skipping the equivalence test?

For this, in general, we would add a @unittest.skip decorator on the test. Typically, this means overriding the test in the specific model's tests and then adding the decorator e.g. like here
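For illustration, overriding and skipping such a test could look roughly like this (the class name below is hypothetical and the skip reason is just an example):

import unittest


class FlaxDinov2ModelTest(unittest.TestCase):
    # Hypothetical override: skip the PT/Flax equivalence check because the bicubic
    # interpolation of the position embeddings differs slightly between frameworks.
    @unittest.skip(reason="Position embedding interpolation differs slightly from the PyTorch implementation")
    def test_equivalence_pt_to_flax(self):
        pass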

For your proposal re weights, this might be a good option as we'd still be checking the rest of the model. When you say "remove", I'm guessing you mean from the respective state dicts?

@MHRDYN7
Contributor Author

MHRDYN7 commented Jul 29, 2024

@amyeroberts thank you. I have tried to resolve all the issues. To summarize, I have decided to keep the jax.image.scale_and_translate layer as it is, even though it results in slightly different output tensors compared to the PyTorch model. This is because: 1. the outputs are still close enough to correctly predict the class of the input image, and 2. the implementation will be exact once JAX updates the resize implementation to follow the PyTorch conventions. As a result, there was no need to skip any tests. In addition, a PR has been opened on the Hub to add the Flax weights. Please let me know if anything is out of order.

@amyeroberts
Collaborator

@MHRDYN7 Great!

  1. the implementation will be exact once JAX updates the resize implementation to follow the PyTorch conventions.

Is there a plan for this to happen in the future?

In addition, a PR has been opened on the Hub to add the Flax weights.

This shouldn't be necessary: all of the model frameworks (TF, PyTorch, Flax) should be able to load the safetensors file.
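If that holds, the Flax class added in this PR should load straight from the existing Hub checkpoint, without dedicated .msgpack weights (a hedged sketch, using the checkpoint referenced elsewhere in this PR):

from transformers import FlaxDinov2ForImageClassification

# Should pull the existing safetensors weights from the Hub and load them into Flax.
model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")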

@MHRDYN7
Contributor Author

MHRDYN7 commented Jul 30, 2024

Is there a plan for this to happen in the future?

The issue from the JAX repo mentioned in my first comment suggests that they did come up with a fix, but no steps were taken and there are no PRs related to the issue. I might just open a PR there if I can; it shouldn't be hard to solve.

This shouldn't be necessary: all of the model frameworks (TF, PyTorch, Flax) should be able to load the safetensors file.

It's good to hear that. Indeed, the models don't necessarily need the .msgpack weights. I had seen Flax weights on the Hub for many models and thought it was a convention to add them.

@amyeroberts
Collaborator

amyeroberts left a comment

Thanks for adding this!

Mostly just some nits. Overall LGTM. @sanchit-gandhi could you give a quick once-over to confirm the Flax side is OK?

(Several review threads on src/transformers/models/dinov2/modeling_flax_dinov2.py were resolved.)

new_height_ratio = jnp.float32(height / math.sqrt(num_positions)) # ? 16/37
new_width_ratio = jnp.float32(width / math.sqrt(num_positions)) # ? 16/37

# patch_pos_embed = jax.image.resize(patch_pos_embed, shape=(hidden_states.shape[0], dim, height, width), method='bicubic', antialias=False)
Collaborator

Why commented out?

Contributor Author

It seems that I mistakenly wrote in an earlier comment that I used jax.image.resize, whilst I actually used jax.image.scale_and_translate. Both functions ultimately call the same helper internally, so either could be used to interpolate the tensor. The reason scale_and_translate() is the better fit is that it lets us set the scale argument explicitly (which is key according to the original Dinov2 repo), whereas resize() determines the scale on its own. I'll remove the commented-out line of code.
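To illustrate the difference between the two calls (a minimal sketch with made-up shapes, not the code from the PR):

import jax
import jax.numpy as jnp

grid = jnp.ones((1, 37, 37, 768))  # dummy position-embedding grid

# resize() only takes the target shape; the scale factors are derived internally.
out_resize = jax.image.resize(grid, shape=(1, 16, 16, 768), method="bicubic")

# scale_and_translate() lets us pass the scale explicitly, mirroring the scale_factor
# that the original Dinov2 code hands to torch.nn.functional.interpolate.
scale = jnp.array([16 / 37, 16 / 37], dtype=jnp.float32)
out_sat = jax.image.scale_and_translate(
    grid,
    shape=(1, 16, 16, 768),
    spatial_dims=(1, 2),
    scale=scale,
    translation=jnp.zeros(2, dtype=jnp.float32),
    method="bicubic",
    antialias=False,
)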

@MHRDYN7
Contributor Author

MHRDYN7 commented Aug 1, 2024

@amyeroberts thanks a lot for the review. All the tests are passing again.

@sanchit-gandhi
Contributor

sanchit-gandhi left a comment

The PR generally looks in good shape! Well done on handling all of the weight initialisations carefully @MHRDYN7 and porting the new functions over to Flax.

The main request from my review is to use # Copied from statements as much as possible. There are many modules / methods that are copied 1-for-1 from existing models in the library. Prepending them with a # Copied from statement (an example is sketched below) helps:

  1. Keep code sync'd across models
  2. Help the reviewer pinpoint which parts of the code to focus on!
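As a purely hypothetical example of the convention (the actual source module and replacement pattern depend on which model is being mirrored):

import flax.linen as nn


# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfAttention with ViT->Dinov2
class FlaxDinov2SelfAttention(nn.Module):
    ...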

Regarding your PR description: I didn't fully understand what the issue was with the position embedding weights - you've defined them as a standard self.param, and the keys look to match those from PyTorch? Let me know if I'm missing something here!

)


class FlaxDinov2PreTrainedModel(FlaxPreTrainedModel):
Contributor

I would copy this from Beit

Suggested change
class FlaxDinov2PreTrainedModel(FlaxPreTrainedModel):
# Copied from transformers.models.beit.modeling_flax_beit.FlaxBeitPreTrainedModel with Beit-> Dinov2, beit -> dinov2

# init input tensors
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

params_rng, dropout_rng = jax.random.split(rng)
Contributor

We're missing the rng for the droppath - copying from Beit is going to fix this
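A rough sketch of what adding the missing rng could look like (names and the exact split pattern here are assumptions, following the Beit-style init rather than the final code):

import jax

rng = jax.random.PRNGKey(0)

# split the key into the three streams the module may consume during init
params_rng, dropout_rng, droppath_rng = jax.random.split(rng, num=3)
rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}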

tests/models/dinov2/test_modeling_flax_dinov2.py
@@ -0,0 +1,259 @@
# coding=utf-8
Contributor

It'd be super helpful to add # Copied from statements in the tests as well!

>>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")

>>> inputs = image_processor(images=image, return_tensors="np")
Contributor

sanchit-gandhi commented Aug 2, 2024

Fine to do "np", since we convert "np" arrays to "jnp" arrays before calling the Flax module! (In fact, "np" is preferable here: "jnp" arrays are automatically created on the accelerator device, whereas "np" arrays always live on the CPU. Creating your input on the CPU and only moving it to the accelerator when required is better for async dispatch.)
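A small sketch of the distinction (illustrative shapes only):

import numpy as np
import jax.numpy as jnp

np_pixel_values = np.zeros((1, 3, 224, 224), dtype=np.float32)  # plain NumPy: stays on the host (CPU)
jnp_pixel_values = jnp.asarray(np_pixel_values)                 # jax.numpy: placed on the default (accelerator) device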

@MHRDYN7
Contributor Author

MHRDYN7 commented Aug 3, 2024

@sanchit-gandhi thanks a lot for the careful review.

A summary of the updates

  • Almost all of your suggestions have been incorporated
  • Unable to use # Copied from Beit for the PreTrainedModel class
  • I was facing some issues trying to use # Copied from in the tests, which is why they were initially kept as they were; this has now been solved, and everything LGTM.

@MHRDYN7
Contributor Author

MHRDYN7 commented Aug 3, 2024

Regarding your PR description: I didn't fully understand what the issue was with the position embedding weights - you've defined them as a standard self.param, and the keys look to match those from PyTorch? Let me know if I'm missing something here!

The position embedding weights can be loaded perfectly. However, these weights are later modified according to the number of patches using F.interpolate (with bicubic mode) in torch. We can replicate this behavior with jax.image.scale_and_translate (or with jax.image.resize), but this function differs slightly from torch's interpolate in the bicubic mode specifically, resulting in slightly different output hidden_states.

Please note: the default config has image_size = 224 and patch_size = 16 (interpolation of the pos_embeds is not needed in this case), but the pretrained checkpoint on the Hub uses image_size = 518 and patch_size = 14.
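For concreteness, the grid sizes involved (a small illustrative calculation, assuming square images and the Hub checkpoint's patch_size = 14):

# Hub checkpoint: image_size = 518, patch_size = 14
pretrained_grid = 518 // 14  # 37 -> 37 * 37 = 1369 patch positions in the pretrained table
# Actual inputs: image_size = 224 with the same patch_size = 14
input_grid = 224 // 14       # 16 -> 16 * 16 = 256 patch positions needed

# so the (1, 1369, dim) position table has to be bicubic-interpolated to (1, 256, dim)
# before it can be added to the patch embeddings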

@amyeroberts
Collaborator

Thanks for the detailed explanation and iterating with us @MHRDYN7! @sanchit-gandhi is off at the moment, but I can see you've addressed his comments, so I think we're OK to merge without his second review.

Final step is running the slow tests for the model before merge. Could you push an empty commit with the message [run_slow] dinov2?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@MHRDYN7
Contributor Author

MHRDYN7 commented Aug 16, 2024

@amyeroberts, I've pushed the required commit. Now I guess it requires your approval to run the slow tests.

@MHRDYN7
Contributor Author

MHRDYN7 commented Aug 18, 2024

@amyeroberts slow tests passed! Ready to be merged

@amyeroberts
Collaborator

amyeroberts left a comment

Great piece of work - thanks for adding!

amyeroberts merged commit 843e5e2 into huggingface:main on Aug 19, 2024
26 checks passed