
Ported Dinov2 to flax #25579

Closed · wants to merge 11 commits
Conversation

@ifeherva (Author):

Ported the Dinov2 model to jax/flax

This PR adds the DINOv2 model in Flax. It is based on the ViT Flax port but uses the existing PyTorch DINOv2 implementation as its base.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@sanchit-gandhi @amyeroberts

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@amyeroberts (Collaborator) left a comment:
Thanks for adding this model!

Very nice and easy-to-read PR. Mostly just nits. The main comments are to add the # Copied from statements and integration tests.

tests/models/dinov2/test_modeling_flax_dinov2.py (outdated)
Collaborator:

There should also be integration tests for the model e.g. like these for beit.

Contributor:

+1 on this!

src/transformers/models/dinov2/modeling_flax_dinov2.py (outdated)
src/transformers/models/dinov2/modeling_flax_dinov2.py (outdated)
    def setup(self):
        out_features = self.config.hidden_size
        hidden_features = int(self.config.hidden_size * self.config.mlp_ratio)
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
Collaborator:

Where are these numbers - 2/3, 7, 8 * 8 - coming from?

Author:

From the PyTorch implementation, which seems to be copying from the original repo: https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py#L57
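For context, a minimal sketch of that arithmetic (assuming hidden_size=768 and mlp_ratio=4 as in dinov2-base; the helper name is hypothetical):

def swiglu_hidden_features(hidden_size: int, mlp_ratio: float) -> int:
    # SwiGLU uses two projection halves, so the width is scaled by 2/3 to keep
    # the parameter count comparable to a plain MLP of width hidden_size * mlp_ratio.
    hidden_features = int(hidden_size * mlp_ratio)
    hidden_features = int(hidden_features * 2 / 3)
    # (x + 7) // 8 * 8 rounds up to the next multiple of 8 for hardware-friendly shapes.
    return (hidden_features + 7) // 8 * 8

# swiglu_hidden_features(768, 4) == 2048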

src/transformers/models/dinov2/modeling_flax_dinov2.py (outdated)
@sanchit-gandhi (Contributor) left a comment:

Looking great already! Thanks for such a clean PR @ifeherva 🙌 Echoing @amyeroberts's points about using # Copied from statements where possible, and adding a few slow integration tests to check that we get the same outputs as the PyTorch model when using real checkpoints (we just need to assert that the values we get out match an expected array, where the expected array is the same as the PyTorch outputs).
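A slow integration test in the style of the BEiT Flax tests could look roughly like this (a sketch only; FlaxDinov2Model is the class added in this PR, the expected_slice values are placeholders to be filled in with the actual PyTorch outputs, and the expected shape assumes the 224x224 preprocessing of facebook/dinov2-base):

import unittest

import numpy as np
import requests
from PIL import Image

from transformers import AutoImageProcessor, FlaxDinov2Model
from transformers.testing_utils import require_flax, require_vision, slow


@require_flax
@require_vision
class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_no_head(self):
        model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")
        image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        inputs = image_processor(images=image, return_tensors="np")

        outputs = model(**inputs)

        # 256 patches + 1 CLS token, hidden size 768 for the base checkpoint
        self.assertEqual(outputs.last_hidden_state.shape, (1, 257, 768))

        # Placeholder values: replace with the slice produced by the PyTorch model.
        expected_slice = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        self.assertTrue(np.allclose(np.array(outputs.last_hidden_state[0, :3, :3]), expected_slice, atol=1e-4))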



DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/dinov2-base",
Contributor:

Have you pushed the Flax weights to a pull request on this repo? It would be nice to do this in tandem with this PR!

Author:

Not yet, I still need to do that I guess :)


    def setup(self):
        self.lambda1 = self.param(
            "lambda1", jax.nn.initializers.constant(self.config.layerscale_value), (self.config.hidden_size,)
Contributor:

Should this parameter always be in float32 precision, or should it respect the dtype of the model? Usually, we cast everything to the dtype of the model, such that the forward computation is done in the specified dtype.

Otherwise, we might upcast inadvertently to a higher dtype than the model dtype during the forward pass.

Suggested change
"lambda1", jax.nn.initializers.constant(self.config.layerscale_value), (self.config.hidden_size,)
"lambda1", jax.nn.initializers.constant(self.config.layerscale_value, dtype=self.dtype), (self.config.hidden_size,)

Author:

I am not sure that is intended; there is even an explicit test checking that all params are initialized to float32 (test_default_params_dtype). If I add the proposed line above, the test will fail.

@sanchit-gandhi (Contributor), Aug 29, 2023:

It shouldn't fail if the attribute self.dtype is float32, no? Then it'll be initialised in float32. Currently, in the PyTorch version, if we send the model to bfloat16, then this parameter lambda1 is also in bfloat16. To have equivalence in Flax, we need to pass the dtype attribute to this param, such that we can put the Flax weights in bfloat16 as required.

Author:

In that case the unit test might be broken. If I pass self.dtype, it actually initializes the param in float16 in the unit test and then complains that it is not float32 :)

Contributor:

That's strange - could you check that the attribute dtype is being passed down correctly from the top level modules to the lower level ones? (e.g. you could just print out self.dtype for all of the modules and see where it goes from fp32 -> fp16)

Author:

The dtype is passed down correctly; however, this test expects the params to be float32 even when you initialize the model with float16: https://github.com/huggingface/transformers/blob/main/tests/test_modeling_flax_common.py#L789
I'm not sure I really understand the logic here.

Contributor:

Looked into this more - you're correct in that we should not change the dtype of the params! We should only change the dtype of the computation. I'll close this thread and suggest the appropriate fix

src/transformers/models/dinov2/modeling_flax_dinov2.py (outdated)
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-patch16-224")
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-patch16-224")
Contributor:

Let's make sure that this checkpoint has flax weights uploaded

Author:

This doesn't exist so I will rework this description. Thanks for flagging it.

Contributor:

Feel free to open a PR on the Hub to add the Flax weights. You can load them with from_pretrained and from_pt=True:

model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-patch16-224", from_pt=True)

And then push the weights to the Hub with:

model.push_to_hub("facebook/dinov2-base-patch16-224", create_pr=True)
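Putting those two steps together for the checkpoint that currently exists, the conversion could look roughly like this (a sketch; it uses the headless base model via the FlaxDinov2Model class added in this PR):

from transformers import FlaxDinov2Model

# Load the PyTorch checkpoint and convert the weights to Flax on the fly.
model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base", from_pt=True)

# Open a pull request on the Hub that adds the converted Flax weights.
model.push_to_hub("facebook/dinov2-base", create_pr=True)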

Author:

Right now only the base model has weights. I converted that and opened a PR on the hub: https://huggingface.co/facebook/dinov2-base/discussions/5

Contributor:

Sure! Thanks for opening a PR to upload them to the Hub, that's great! It would be cool to make sure the fine-tuned version has weights before we merge this PR

Author:

I am not sure such weights are in the public domain. At least I couldn't find them... :(

Contributor:

In that case let's just update the code snippet to use the model variant that has the weights - could you use the same checkpoint that is used in the PyTorch dinov2 code possibly?

Author:

Yeah, I changed the code yesterday to use the ..-base model for now. Once this PR gets merged, I can convert the other (larger) ones as well and push them to the Hub.

Contributor:

Cool - sounds good :)

tests/models/dinov2/test_modeling_flax_dinov2.py (outdated)
Contributor:

+1 on this!

@sanchit-gandhi (Contributor) left a comment:

Thanks for iterating @ifeherva - a few pending points but otherwise looking good!


    def setup(self):
        self.lambda1 = self.param(
            "lambda1", jax.nn.initializers.constant(self.config.layerscale_value), (self.config.hidden_size,)
Contributor:

Looked into this more - you're correct in that we should not change the dtype of the params! We should only change the dtype of the computation. I'll close this thread and suggest the appropriate fix

        )

    def __call__(self, hidden_state):
        return hidden_state * self.lambda1
Contributor:

Suggested change
return hidden_state * self.lambda1
hidden_state = hidden_state * self.lambda1
return hidden_state.astype(self.dtype)
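Putting the thread's conclusion together, the layer-scale module would then look roughly like this (a sketch only; the class name is assumed to mirror the PyTorch Dinov2LayerScale, the params stay in float32, and only the computation is cast):

import flax.linen as nn
import jax
import jax.numpy as jnp

from transformers import Dinov2Config


class FlaxDinov2LayerScale(nn.Module):
    config: Dinov2Config
    dtype: jnp.dtype = jnp.float32  # dtype of the computation, not of the params

    def setup(self):
        # The parameter itself is always initialized in float32.
        self.lambda1 = self.param(
            "lambda1",
            jax.nn.initializers.constant(self.config.layerscale_value),
            (self.config.hidden_size,),
        )

    def __call__(self, hidden_state):
        # Only the computation is cast to the requested dtype.
        hidden_state = hidden_state * self.lambda1
        return hidden_state.astype(self.dtype)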


>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
Contributor:

Still pending! If you could address this in a similar way to the PyTorch code:

>>> list(feature_maps[-1].shape)
[1, 768, 16, 16]

Suggested change
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
>>> model.config.id2label[predicted_class_idx.item()]
# put the predicted class here (without the hash symbol)
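For the headless base checkpoint that does have weights, a docstring example in that style could look roughly like this (a sketch; the expected shape assumes the default 224x224 preprocessing of facebook/dinov2-base and should be verified against the real outputs):

>>> from transformers import AutoImageProcessor, FlaxDinov2Model
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
>>> model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")

>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> list(outputs.last_hidden_state.shape)
[1, 257, 768]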

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot closed this Oct 20, 2023
@sanchit-gandhi (Contributor):

Hey @ifeherva! Given you made a great start on this PR, it would be super nice if you were able to see it through to completion! I'm on hand to help with any questions/queries 🤗 Of course, if you are busy, there's no pressure to finish this. In that case, we can open it up to the community to see if anyone is able to finish the integration so that this work is merged into main.

@MHRDYN7 mentioned this pull request Jul 14, 2024