Ported Dinov2 to flax #25579

Closed
wants to merge 11 commits
2 changes: 1 addition & 1 deletion docs/source/en/index.md
@@ -333,7 +333,7 @@ Flax), PyTorch, and/or TensorFlow.
| DETA | ✅ | ❌ | ❌ |
| DETR | ✅ | ❌ | ❌ |
| DiNAT | ✅ | ❌ | ❌ |
| DINOv2 | ✅ | ❌ | ❌ |
| DINOv2 | ✅ | ❌ | ✅ |
| DistilBERT | ✅ | ✅ | ✅ |
| DonutSwin | ✅ | ❌ | ❌ |
| DPR | ✅ | ✅ | ❌ |
11 changes: 11 additions & 0 deletions docs/source/en/model_doc/dinov2.md
@@ -43,3 +43,14 @@ The original code can be found [here](https://github.com/facebookresearch/dinov2).

[[autodoc]] Dinov2ForImageClassification
- forward

## FlaxDinov2Model

[[autodoc]] FlaxDinov2Model
- __call__

## FlaxDinov2ForImageClassification

[[autodoc]] FlaxDinov2ForImageClassification
- __call__
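
A minimal usage sketch of the two new classes, for anyone trying the port locally (illustrative, not part of this diff; it assumes the PyTorch checkpoint `facebook/dinov2-base` on the Hub and weight conversion via `from_pt=True`):

```python
# Illustrative sketch: the checkpoint name and from_pt conversion are assumptions, not part of the PR.
import requests
from PIL import Image

from transformers import AutoImageProcessor, FlaxDinov2Model

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
# from_pt=True converts the PyTorch weights, since no native Flax checkpoint is assumed to exist.
model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base", from_pt=True)

inputs = processor(images=image, return_tensors="np")  # Flax models take NumPy arrays
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```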

8 changes: 8 additions & 0 deletions src/transformers/__init__.py
@@ -4012,6 +4012,13 @@
"FlaxCLIPVisionPreTrainedModel",
]
)
_import_structure["models.dinov2"].extend(
[
"FlaxDinov2ForImageClassification",
"FlaxDinov2Model",
"FlaxDinov2PreTrainedModel",
]
)
_import_structure["models.distilbert"].extend(
[
"FlaxDistilBertForMaskedLM",
@@ -7476,6 +7483,7 @@
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
from .models.dinov2 import FlaxDinov2ForImageClassification, FlaxDinov2Model, FlaxDinov2PreTrainedModel
from .models.distilbert import (
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
28 changes: 28 additions & 0 deletions src/transformers/modeling_flax_outputs.py
@@ -698,3 +698,31 @@ class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
encoder_last_hidden_state: Optional[jnp.ndarray] = None
encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
encoder_attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxImageClassifierOutput(ModelOutput):
"""
Base class for outputs of image classification models.

Args:
loss (`jnp.ndarray` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also
called feature maps) of the model at the output of each stage.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, patch_size,
sequence_length)`.

Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""

loss: Optional[jnp.ndarray] = None
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
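
For orientation, a small sketch of how the new output class is constructed and read (values are illustrative, not taken from the diff):

```python
import jax.numpy as jnp

from transformers.modeling_flax_outputs import FlaxImageClassifierOutput

# Illustrative: logits for a batch of 2 images over 1000 labels, as a classification head would return.
output = FlaxImageClassifierOutput(
    logits=jnp.zeros((2, 1000)),  # (batch_size, config.num_labels)
    hidden_states=None,           # filled only when output_hidden_states=True
    attentions=None,              # filled only when output_attentions=True
)
predicted_class = output.logits.argmax(axis=-1)  # per-example index of the highest score
```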
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -37,6 +37,7 @@
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("bloom", "FlaxBloomModel"),
("clip", "FlaxCLIPModel"),
("dinov2", "FlaxDinov2Model"),
("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"),
("gpt-sw3", "FlaxGPT2Model"),
@@ -122,6 +123,7 @@
[
# Model for Image-classsification
("beit", "FlaxBeitForImageClassification"),
("dinov2", "FlaxDinov2ForImageClassification"),
("regnet", "FlaxRegNetForImageClassification"),
("resnet", "FlaxResNetForImageClassification"),
("vit", "FlaxViTForImageClassification"),
21 changes: 21 additions & 0 deletions src/transformers/models/dinov2/__init__.py
@@ -16,6 +16,7 @@
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_torch_available,
)

@@ -38,6 +39,18 @@
"Dinov2Backbone",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_dinov2"] = [
"FlaxDinov2ForImageClassification",
"FlaxDinov2Model",
"FlaxDinov2PreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_dinov2 import DINOV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Dinov2Config, Dinov2OnnxConfig

@@ -55,6 +68,14 @@
Dinov2PreTrainedModel,
)

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_dinov2 import FlaxDinov2ForImageClassification, FlaxDinov2Model, FlaxDinov2PreTrainedModel

else:
import sys

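
Downstream code can mirror the availability guard used above before touching the new classes; a small sketch relying on the same `is_flax_available` utility the diff imports:

```python
# Sketch only: gate Flax-specific code paths the same way the package __init__ does.
from transformers.utils import is_flax_available

if is_flax_available():
    from transformers import FlaxDinov2Model  # resolved lazily through _import_structure
else:
    FlaxDinov2Model = None  # JAX/Flax not installed; fall back to the PyTorch classes instead
```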