Commit: Add a few more layers first pass
D-Roberts committed Apr 8, 2023
1 parent 9122f99 commit 1a25b23
Showing 1 changed file with 116 additions and 7 deletions.
123 changes: 116 additions & 7 deletions src/transformers/models/efficientformer/modeling_tf_efficientformer.py
@@ -17,6 +17,7 @@
import itertools
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import tensorflow as tf
@@ -164,7 +165,6 @@ def __init__(
    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
@@ -228,7 +228,7 @@ def __init__(self, config: EfficientFormerConfig, out_channels: int, **kwargs):
        self.activation = tf.keras.layers.Activation(tf.keras.activations.relu)

    def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
        features = self.convolution1(pixel_values)
        features = self.convolution1(pixel_values)  # TODO: check on how to apply this here
        features = self.batchnorm_before(features)
        features = self.convolution2(features)
        features = self.batchnorm_after(features)
@@ -248,13 +248,65 @@ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:


class TFEfficientFormerDenseMlp(tf.keras.layers.Layer):
    pass
    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.linear_in = tf.keras.layers.Dense(
            units=hidden_features, kernel_initializer=get_initializer(config.initializer_range), name="dense_in"
        )
        self.activation = get_tf_activation(config.hidden_act)
        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.linear_out = tf.keras.layers.Dense(
            units=out_features, kernel_initializer=get_initializer(config.initializer_range), name="dense_out"
        )

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.linear_in(inputs=hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        hidden_states = self.linear_out(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)

        return hidden_states
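
For orientation, a minimal usage sketch of the dense MLP above; the default EfficientFormerConfig() and the token sizes are illustrative assumptions, not part of this commit:

config = EfficientFormerConfig()  # assumed defaults; hidden_act and dropout come from the config
mlp = TFEfficientFormerDenseMlp(config, in_features=448, hidden_features=4 * 448, name="mlp")
tokens = tf.random.normal((2, 49, 448))   # (batch, sequence_length, in_features)
output = mlp(tokens, training=False)      # -> (2, 49, 448); out_features falls back to in_features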


class TFEfficientFormerConvMlp(tf.keras.layers.Layer):
    pass
    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.convolution1 = tf.keras.layers.Conv2D(
            hidden_features, kernel_size=1, name="conv1"
        )  # TODO: what strides and kernel_size
        self.activation = get_tf_activation(config.hidden_act)
        self.convolution2 = tf.keras.layers.Conv2D(out_features, kernel_size=1, name="conv2")
        self.dropout = tf.keras.layers.Dropout(rate=drop)

        self.batchnorm_before = tf.keras.layers.BatchNormalization(axis=1)
        self.batchnorm_after = tf.keras.layers.BatchNormalization(axis=1)

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.convolution1(hidden_states)  # TODO: check EfficientNet code
        hidden_states = self.batchnorm_before(hidden_states, training=training)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)

        hidden_states = self.convolution2(hidden_states)
        hidden_states = self.batchnorm_after(hidden_states, training=training)
        hidden_states = self.dropout(inputs=hidden_states, training=training)

        return hidden_states
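
One thing worth flagging in the block above: tf.keras.layers.Conv2D defaults to channels_last input, while BatchNormalization(axis=1) normalizes axis 1 as a channels-first layout would (the PyTorch reference model is channels-first), so the two conventions still need to be reconciled. As written the layer does run on channels_last feature maps; a hedged sketch with purely illustrative sizes:

config = EfficientFormerConfig()  # assumed defaults
conv_mlp = TFEfficientFormerConvMlp(config, in_features=48, hidden_features=4 * 48, name="conv_mlp")
feature_map = tf.random.normal((2, 56, 56, 48))      # channels_last: (batch, height, width, channels)
projected = conv_mlp(feature_map, training=False)    # -> (2, 56, 56, 48); note axis=1 here is height, not channels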


# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath
class TFEfficientFormerDropPath(tf.keras.layers.Layer):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
@@ -277,13 +329,70 @@ def call(self, x, training=None):


class TFEfficientFormerFlat(tf.keras.layers.Layer):
    def __init__(self):
        pass
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, hidden_states: tf.Tensor) -> Tuple[tf.Tensor]:
        # Flatten the spatial dimensions into a token axis, assuming channels_last feature maps:
        # (batch, height, width, channels) -> (batch, height * width, channels)
        batch_size = tf.shape(hidden_states)[0]
        in_channels = tf.shape(hidden_states)[-1]
        hidden_states = tf.reshape(hidden_states, shape=[batch_size, -1, in_channels])
        return hidden_states
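
The Flat layer is the bridge between the convolutional (4D) stages and the attention (3D) stages: it turns a feature map into a sequence of patch tokens. A small sketch, again with illustrative sizes only:

feature_map = tf.random.normal((2, 7, 7, 448))    # (batch, height, width, channels)
tokens = TFEfficientFormerFlat(name="flat")(feature_map)
print(tokens.shape)                               # (2, 49, 448): one token per spatial position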


class TFEfficientFormerMeta3D(tf.keras.layers.Layer):
    pass
    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs):
        super().__init__(**kwargs)

        self.token_mixer = TFEfficientFormerSelfAttention(
            dim=config.dim,
            key_dim=config.key_dim,
            num_heads=config.num_attention_heads,
            attention_ratio=config.attention_ratio,
            resolution=config.resolution,
        )
        self.dim = dim
        self.config = config
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm1")
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm2")
        mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
        self.mlp = TFEfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim)

        self.drop_path = TFEfficientFormerDropPath(drop_path) if drop_path > 0.0 else tf.keras.layers.Identity()
        self.use_layer_scale = config.use_layer_scale

    def build(self, input_shape: tf.TensorShape):
        self.layer_scale_1 = None
        self.layer_scale_2 = None

        if self.config.use_layer_scale:
            self.layer_scale_1 = self.add_weight(
                shape=(self.dim,),
                initializer=tf.keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_1",
            )
            self.layer_scale_2 = self.add_weight(
                shape=(self.dim,),
                initializer=tf.keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_2",
            )
        super().build(input_shape)

    def call(
        self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
    ) -> Tuple[tf.Tensor]:
        self_attention_outputs = self.token_mixer(self.layernorm1(hidden_states), output_attentions)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.use_layer_scale:
            layer_output = hidden_states + self.drop_path(
                tf.multiply(tf.expand_dims(tf.expand_dims(self.layer_scale_1, axis=0), axis=0), attention_output),
                training=training,
            )
            layer_output = layer_output + self.drop_path(
                tf.multiply(
                    tf.expand_dims(tf.expand_dims(self.layer_scale_2, axis=0), axis=0),
                    self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training),
                ),
                training=training,
            )
        else:
            layer_output = hidden_states + self.drop_path(attention_output, training=training)
            layer_output = layer_output + self.drop_path(
                self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training),
                training=training,
            )

        outputs = (layer_output,) + outputs

        return outputs
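
Each Meta3D block applies two pre-norm residual updates: hidden_states plus DropPath(layer_scale_1 * Attention(LN(hidden_states))), followed by the same pattern with layer_scale_2 and the dense MLP. The learned layer_scale vectors have shape (dim,) and are expanded to (1, 1, dim) so they broadcast over the batch and token axes; a small sketch with illustrative numbers:

dim = 448
layer_scale = tf.fill([dim], 1e-5)       # a typical layer_scale_init_value
tokens = tf.random.normal((2, 49, dim))  # (batch, sequence_length, dim)
scaled = tf.expand_dims(tf.expand_dims(layer_scale, axis=0), axis=0) * tokens  # (1, 1, dim) broadcasts to (2, 49, dim)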

class TFEfficientFormerMeta3DLayers(tf.keras.layers.Layer):
    pass
