Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TensorFlow Wav2Vec2 for sequence classification #22073

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/wav2vec2.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower
[[autodoc]] TFWav2Vec2Model
- call

## TFWav2Vec2ForSequenceClassification

[[autodoc]] TFWav2Vec2ForSequenceClassification
- call

## TFWav2Vec2ForCTC

[[autodoc]] TFWav2Vec2ForCTC
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3381,6 +3381,7 @@
[
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC",
"TFWav2Vec2ForSequenceClassification",
"TFWav2Vec2Model",
"TFWav2Vec2PreTrainedModel",
]
Expand Down Expand Up @@ -6509,6 +6510,7 @@
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
TFWav2Vec2ForSequenceClassification,
TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel,
)
Expand Down
13 changes: 13 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
]
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])

TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
Expand Down Expand Up @@ -468,6 +469,9 @@
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)


class TFAutoModel(_BaseAutoModelClass):
Expand All @@ -477,6 +481,15 @@ class TFAutoModel(_BaseAutoModelClass):
TFAutoModel = auto_class_update(TFAutoModel)


class TFAutoModelForAudioClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


TFAutoModelForAudioClassification = auto_class_update(
TFAutoModelForAudioClassification, head_doc="audio classification"
)


class TFAutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/wav2vec2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
"TFWav2Vec2ForCTC",
"TFWav2Vec2Model",
"TFWav2Vec2PreTrainedModel",
"TFWav2Vec2ForSequenceClassification",
]

try:
Expand Down Expand Up @@ -108,6 +109,7 @@
from .modeling_tf_wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
TFWav2Vec2ForSequenceClassification,
TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel,
)
Expand Down
162 changes: 161 additions & 1 deletion src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput
from ...modeling_tf_utils import (
TFPreTrainedModel,
get_initializer,
Expand Down Expand Up @@ -1212,6 +1212,46 @@ def serving(self, inputs):

return self.serving_output(output)

def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

def _conv_out_length(input_length, kernel_size, stride):
return tf.math.floordiv(input_length - kernel_size, stride) + 1

for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths

def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: tf.Tensor, add_adapter=None
):
non_padded_lengths = tf.math.cumsum(attention_mask, axis=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
output_lengths = tf.cast(output_lengths, tf.int32)
batch_size = tf.shape(attention_mask)[0]
# check device here
attention_mask = tf.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, name="attention_mask"
) # these two operations makes sure that all values before the output lengths idxs are attended to
## check device
attention_mask = tf.tensor_scatter_nd_update(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me - will leave the TF specificities to @Rocketknight1!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is correct, as far as I can see! tensor_scatter_nd_update is TF's equivalent to JAX's array assignment with the .at[].set() operation.

attention_mask,
indices=tf.stack([tf.range(batch_size), output_lengths - 1], axis=1),
updates=tf.ones([batch_size], dtype=attention_mask.dtype),
)
attention_mask = tf.reverse(attention_mask, axis=[-1])
attention_mask = tf.cumsum(attention_mask, axis=-1)
attention_mask = tf.reverse(attention_mask, axis=[-1])
nandwalritik marked this conversation as resolved.
Show resolved Hide resolved
attention_mask = tf.cast(attention_mask, tf.bool)
return attention_mask


WAV_2_VEC_2_START_DOCSTRING = r"""

Expand Down Expand Up @@ -1552,3 +1592,123 @@ def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)


class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also looks good to me (matching PyTorch) - will leave the TF specificities again to @Rocketknight1

def __init__(self, config):
super().__init__(config)
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
self.num_layers = config.num_hidden_layers + 1
with tf.name_scope(self._name_scope()):
if config.use_weighted_layer_sum:
self.layer_weights = self.add_weight(
shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights"
)
self.config = config
self.projector = tf.keras.layers.Dense(units=config.classifier_proj_size, name="projector")
self.classifier = tf.keras.layers.Dense(units=config.num_labels, activation=None, name="classifier")

def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
not be updated during training.
"""
warnings.warn(
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5."
"Please use the equivalent `freeze_feature_encoder` method instead.",
FutureWarning,
)
self.freeze_feature_encoder()

def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.wav2vec2.feature_extractor.trainable = False

def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for layer in self.wav2vec2.layers:
layer.trainable = False

@unpack_inputs
def call(
self,
input_values: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[tf.Tensor] = None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = tf.stack(hidden_states, axis=1)
norm_weights = tf.nn.softmax(self.layer_weights, axis=-1)
hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1)
else:
hidden_states = outputs[0]

hidden_states = self.projector(hidden_states)
if attention_mask is None:
pooled_output = tf.reduce_mean(hidden_states, axis=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
padding_mask_float = tf.cast(padding_mask, hidden_states.dtype)
hidden_states = tf.multiply(hidden_states, tf.expand_dims(padding_mask_float, axis=-1))
pooled_output = tf.divide(
tf.reduce_sum(hidden_states, axis=1), tf.expand_dims(tf.reduce_sum(padding_mask_float, axis=1), axis=1)
)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels]))
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output

return TFSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

def serving_output(self, output):
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFSequenceClassifierOutput(
logits=output.logits,
hidden_states=hidden_states,
attentions=attentions,
)

@tf.function(
input_signature=[
{
"input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}
]
)
def serving(self, inputs):
output = self.call(input_values=inputs)

return self.serving_output(output)
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_tf_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2538,6 +2538,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFWav2Vec2ForSequenceClassification(metaclass=DummyObject):
_backends = ["tf"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFWav2Vec2Model(metaclass=DummyObject):
_backends = ["tf"]

Expand Down
Loading