Skip to content

Commit

Permalink
Add Flax Dinov2 (#31960)
Browse files Browse the repository at this point in the history
* tfmsenv restored in main

* installed flax

* forward pass done and all tests passed

* make fix-copies and cleaning the scripts

* fixup attempt 1

* fixup attempt 2

* fixup third attempt

* fixup attempt 4

* fixup attempt 5

* dinov2 doc fixed

* FlaxDinov2Model + ForImageClassification added to OBJECTS_TO_IGNORE

* external pos_encoding layer removed

* fixup attempt 6

* fixed integration test values

* fixup attempt 7

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: amyeroberts <[email protected]>

* comments removed

* comment removed from the test

* fixup

* Update src/transformers/models/dinov2/modeling_flax_dinov2.py

Co-authored-by: Sanchit Gandhi <[email protected]>

* new fixes 1

* interpolate_pos_encoding function removed

* droppath rng fixed, pretrained beit copied-from still not working

* modeling_flax_dinov2.py reformatted

* Update tests/models/dinov2/test_modeling_flax_dinov2.py

Co-authored-by: Sanchit Gandhi <[email protected]>

* added Copied from, to the tests

* copied from statements removed from tests

* fixed copied from statements in the tests

* [run_slow] dinov2

---------

Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Sanchit Gandhi <[email protected]>
  • Loading branch information
3 people authored Aug 19, 2024
1 parent 52cb403 commit 843e5e2
Show file tree
Hide file tree
Showing 9 changed files with 1,141 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ Flax), PyTorch, and/or TensorFlow.
| [DETR](model_doc/detr) ||||
| [DialoGPT](model_doc/dialogpt) ||||
| [DiNAT](model_doc/dinat) ||||
| [DINOv2](model_doc/dinov2) ||| |
| [DINOv2](model_doc/dinov2) ||| |
| [DistilBERT](model_doc/distilbert) ||||
| [DiT](model_doc/dit) ||||
| [DonutSwin](model_doc/donut) ||||
Expand Down
20 changes: 20 additions & 0 deletions docs/source/en/model_doc/dinov2.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ If you're interested in submitting a resource to be included here, please feel f

[[autodoc]] Dinov2Config

<frameworkcontent>
<pt>

## Dinov2Model

[[autodoc]] Dinov2Model
Expand All @@ -81,3 +84,20 @@ If you're interested in submitting a resource to be included here, please feel f

[[autodoc]] Dinov2ForImageClassification
- forward

</pt>
<jax>

## FlaxDinov2Model

[[autodoc]] FlaxDinov2Model
- __call__


## FlaxDinov2ForImageClassification

[[autodoc]] FlaxDinov2ForImageClassification
- __call__

</jax>
</frameworkcontent>
12 changes: 12 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4587,6 +4587,13 @@
"FlaxCLIPVisionPreTrainedModel",
]
)
_import_structure["models.dinov2"].extend(
[
"FlaxDinov2Model",
"FlaxDinov2ForImageClassification",
"FlaxDinov2PreTrainedModel",
]
)
_import_structure["models.distilbert"].extend(
[
"FlaxDistilBertForMaskedLM",
Expand Down Expand Up @@ -8706,6 +8713,11 @@
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
from .models.dinov2 import (
FlaxDinov2ForImageClassification,
FlaxDinov2Model,
FlaxDinov2PreTrainedModel,
)
from .models.distilbert import (
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("bloom", "FlaxBloomModel"),
("clip", "FlaxCLIPModel"),
("dinov2", "FlaxDinov2Model"),
("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"),
("gemma", "FlaxGemmaModel"),
Expand Down Expand Up @@ -124,6 +125,7 @@
[
# Model for Image-classsification
("beit", "FlaxBeitForImageClassification"),
("dinov2", "FlaxDinov2ForImageClassification"),
("regnet", "FlaxRegNetForImageClassification"),
("resnet", "FlaxResNetForImageClassification"),
("vit", "FlaxViTForImageClassification"),
Expand Down
25 changes: 25 additions & 0 deletions src/transformers/models/dinov2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_torch_available,
)

Expand All @@ -35,6 +36,18 @@
"Dinov2Backbone",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_dinov2"] = [
"FlaxDinov2ForImageClassification",
"FlaxDinov2Model",
"FlaxDinov2PreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_dinov2 import Dinov2Config, Dinov2OnnxConfig

Expand All @@ -51,6 +64,18 @@
Dinov2PreTrainedModel,
)

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_dinov2 import (
FlaxDinov2ForImageClassification,
FlaxDinov2Model,
FlaxDinov2PreTrainedModel,
)

else:
import sys

Expand Down
Loading

0 comments on commit 843e5e2

Please sign in to comment.