Skip to content

Commit

Permalink
Add serving_output implementation and fix layer_weights
Browse files Browse the repository at this point in the history
  • Loading branch information
nandwalritik committed Apr 17, 2023
1 parent 592d47f commit 983965e
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1599,17 +1599,15 @@ def __init__(self, config):
super().__init__(config)
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
self.num_layers = config.num_hidden_layers + 1
with tf.name_scope(self._name_scope()):
if config.use_weighted_layer_sum:
self.layer_weights = self.add_weight(
shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights"
)
self.config = config
self.projector = tf.keras.layers.Dense(units=config.classifier_proj_size, name="projector")
self.classifier = tf.keras.layers.Dense(units=config.num_labels, activation=None, name="classifier")

def build(self, input_shape):
    """Lazily create this head's own weights when the layer is first built.

    Creating `layer_weights` here (rather than in `__init__`) follows the Keras
    deferred-build convention, so the variable is registered under the correct
    name scope.
    """
    if self.config.use_weighted_layer_sum:
        # One trainable scalar per hidden state (num_hidden_layers + 1 embeddings),
        # initialised to ones so the initial weighted sum is a plain average (pre-softmax).
        self.layer_weights = self.add_weight(
            name="layer_weights",
            shape=(self.num_layers,),
            initializer="ones",
            trainable=True,
        )
    super().build(input_shape)

def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
Expand Down Expand Up @@ -1690,6 +1688,16 @@ def call(
attentions=outputs.attentions,
)

def serving_output(self, output):
    """Prepare a model output for a TF serving signature.

    Converts the optional `hidden_states` / `attentions` tuples to tensors when
    the config requests them, and drops them (``None``) otherwise.
    """
    # Only materialise the optional fields the config asks for.
    if self.config.output_hidden_states:
        hidden_states = tf.convert_to_tensor(output.hidden_states)
    else:
        hidden_states = None
    if self.config.output_attentions:
        attentions = tf.convert_to_tensor(output.attentions)
    else:
        attentions = None

    return TFSequenceClassifierOutput(
        logits=output.logits,
        hidden_states=hidden_states,
        attentions=attentions,
    )

@tf.function(
input_signature=[
{
Expand Down

0 comments on commit 983965e

Please sign in to comment.