Add TensorFlow Wav2Vec2 for sequence classification #22073
Merged: Rocketknight1 merged 7 commits into huggingface:main from nandwalritik:add_wav2vec2_seq_classification on Apr 26, 2023.
Changes shown from 6 commits.

Commits (7, all by nandwalritik):
- d3f74c2 Add initial changes for TF wav2vec2 for sequence classification
- 8252b9f Add suggested changes
- 592d47f Add serving and serving output methods
- 983965e Add serving_output implementation and fix layer_weights
- c22aef5 Add fixes
- 7a92f7d Fixed test cases
- e6080ac Fixing test and adding suggested changes
```diff
@@ -22,7 +22,7 @@
 import tensorflow as tf

 from ...activations_tf import get_tf_activation
-from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
+from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput
 from ...modeling_tf_utils import (
     TFPreTrainedModel,
     get_initializer,
```
```diff
@@ -1212,6 +1212,46 @@ def serving(self, inputs):

         return self.serving_output(output)

+    def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):
+        """
+        Computes the output length of the convolutional layers
+        """
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D conv output length: floor((input_length - kernel_size) / stride) + 1
+            return tf.math.floordiv(input_length - kernel_size, stride) + 1
+
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(
+        self, feature_vector_length: int, attention_mask: tf.Tensor, add_adapter=None
+    ):
+        non_padded_lengths = tf.math.cumsum(attention_mask, axis=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+        output_lengths = tf.cast(output_lengths, tf.int32)
+        batch_size = tf.shape(attention_mask)[0]
+        attention_mask = tf.zeros(
+            (batch_size, feature_vector_length), dtype=attention_mask.dtype, name="attention_mask"
+        )
+        # these two operations make sure that all values before the output length indices are attended to:
+        # the scatter writes a 1 at each sequence's last valid frame, and the
+        # reverse/cumsum/reverse pass fills in ones at every earlier position
+        attention_mask = tf.tensor_scatter_nd_update(
+            attention_mask,
+            indices=tf.stack([tf.range(batch_size), output_lengths - 1], axis=1),
+            updates=tf.ones([batch_size], dtype=attention_mask.dtype),
+        )
+        attention_mask = tf.reverse(attention_mask, axis=[-1])
+        attention_mask = tf.cumsum(attention_mask, axis=-1)
+        attention_mask = tf.reverse(attention_mask, axis=[-1])
+
+        attention_mask = tf.cast(attention_mask, tf.bool)
+        return attention_mask
+
+
 WAV_2_VEC_2_START_DOCSTRING = r"""
```
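For intuition, here is a minimal standalone sketch of what these two helpers compute. The kernel sizes, strides, and input lengths are made up for illustration, not taken from any real config:

```python
import tensorflow as tf

# Made-up conv stack (stand-ins for config.conv_kernel / config.conv_stride).
conv_kernel = [10, 3, 3]
conv_stride = [5, 2, 2]

def conv_out_length(length, kernel, stride):
    # Same recurrence as _conv_out_length above.
    return (length - kernel) // stride + 1

lengths = [100, 60]  # raw-audio sample counts for two sequences
for kernel, stride in zip(conv_kernel, conv_stride):
    lengths = [conv_out_length(length, kernel, stride) for length in lengths]
print(lengths)  # [4, 2] frames after the conv stack

# Build the frame-level mask the same way _get_feature_vector_attention_mask does.
output_lengths = tf.constant(lengths, dtype=tf.int32)
batch_size, feature_vector_length = 2, max(lengths)
mask = tf.zeros((batch_size, feature_vector_length), dtype=tf.int32)
mask = tf.tensor_scatter_nd_update(
    mask,
    indices=tf.stack([tf.range(batch_size), output_lengths - 1], axis=1),
    updates=tf.ones([batch_size], dtype=tf.int32),
)  # [[0, 0, 0, 1], [0, 1, 0, 0]]: a single 1 at each last valid frame
mask = tf.reverse(tf.cumsum(tf.reverse(mask, axis=[-1]), axis=-1), axis=[-1])
print(mask.numpy())  # [[1 1 1 1], [1 1 0 0]]
```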
```diff
@@ -1552,3 +1592,123 @@ def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
         hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
         attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
         return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
+
+
+class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel):
```

Reviewer comment on the class definition: "Also looks good to me (matching PyTorch) - will leave the TF specificities again to @Rocketknight1"
The class body continues:

```diff
+    def __init__(self, config):
+        super().__init__(config)
+        self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
+        self.num_layers = config.num_hidden_layers + 1
+        with tf.name_scope(self._name_scope()):
+            if config.use_weighted_layer_sum:
+                self.layer_weights = self.add_weight(
+                    shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights"
+                )
+        self.config = config
+        self.projector = tf.keras.layers.Dense(units=config.classifier_proj_size, name="projector")
+        self.classifier = tf.keras.layers.Dense(units=config.num_labels, activation=None, name="classifier")
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters
+        will not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters
+        will not be updated during training.
+        """
+        self.wav2vec2.feature_extractor.trainable = False
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will
+        not be updated during training. Only the classification head will be updated.
+        """
+        for layer in self.wav2vec2.layers:
+            layer.trainable = False
```
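As a hedged usage sketch of these freezing helpers during fine-tuning (the checkpoint name and label count are illustrative, not prescribed by this PR):

```python
from transformers import TFWav2Vec2ForSequenceClassification

# Illustrative checkpoint and label count; substitute whatever you fine-tune on.
model = TFWav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base", num_labels=8
)

# Common recipe: keep the convolutional feature encoder frozen and train
# the transformer plus the classification head...
model.freeze_feature_encoder()

# ...or freeze everything except the projector/classifier head:
# model.freeze_base_model()
```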
The diff continues with the forward pass:

```diff
+    @unpack_inputs
+    def call(
+        self,
+        input_values: tf.Tensor,
+        attention_mask: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[tf.Tensor] = None,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = tf.stack(hidden_states, axis=1)
+            norm_weights = tf.nn.softmax(self.layer_weights, axis=-1)
+            hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1)
+        else:
+            hidden_states = outputs[0]
+
+        hidden_states = self.projector(hidden_states)
+        if attention_mask is None:
+            pooled_output = tf.reduce_mean(hidden_states, axis=1)
+        else:
+            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+            padding_mask_float = tf.cast(padding_mask, hidden_states.dtype)
+            hidden_states = tf.multiply(hidden_states, tf.expand_dims(padding_mask_float, axis=-1))
+            pooled_output = tf.divide(
+                tf.reduce_sum(hidden_states, axis=1), tf.expand_dims(tf.reduce_sum(padding_mask_float, axis=1), axis=1)
+            )
+        logits = self.classifier(pooled_output)
+        loss = None
+        if labels is not None:
+            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+            loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels]))
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
```
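Two pieces of this forward pass are worth a small aside: the weighted layer sum (a learned softmax-weighted average over all transformer layers' hidden states) and the masked mean pooling (averaging only over valid frames). A minimal sketch with made-up shapes:

```python
import tensorflow as tf

batch, frames, hidden, num_layers = 2, 4, 8, 3  # made-up sizes

# Weighted layer sum: stack per-layer hidden states, then take a
# softmax-weighted average across the layer axis.
layers = [tf.random.normal((batch, frames, hidden)) for _ in range(num_layers)]
stacked = tf.stack(layers, axis=1)                # (batch, layers, frames, hidden)
layer_weights = tf.Variable(tf.ones(num_layers))  # learned; initialized to ones
norm_weights = tf.nn.softmax(layer_weights, axis=-1)
summed = tf.reduce_sum(stacked * tf.reshape(norm_weights, [-1, 1, 1]), axis=1)
print(summed.shape)  # (2, 4, 8)

# Masked mean pooling: zero out padded frames, then divide by the number
# of valid frames rather than the padded length.
padding_mask = tf.constant([[1., 1., 1., 1.],
                            [1., 1., 0., 0.]])    # frame-level validity
masked = summed * tf.expand_dims(padding_mask, axis=-1)
pooled = tf.reduce_sum(masked, axis=1) / tf.reduce_sum(padding_mask, axis=1, keepdims=True)
print(pooled.shape)  # (2, 8)
```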
The serving methods close out the class:

```diff
+
+    def serving_output(self, output):
+        hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+        attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+        return TFSequenceClassifierOutput(
+            logits=output.logits,
+            hidden_states=hidden_states,
+            attentions=attentions,
+        )
+
+    @tf.function(
+        input_signature=[
+            {
+                "input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
+                "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+                "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
+            }
+        ]
+    )
+    def serving(self, inputs):
+        output = self.call(input_values=inputs)
+
+        return self.serving_output(output)
```
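Putting it together, a hedged end-to-end sketch of how the new head could be used once merged. The checkpoint name, label count, and audio are illustrative; the feature extractor is the standard transformers one, and a freshly initialized classifier head is expected when loading a base checkpoint:

```python
import numpy as np
from transformers import AutoFeatureExtractor, TFWav2Vec2ForSequenceClassification

# Illustrative checkpoint; any Wav2Vec2 checkpoint would be handled the same way.
checkpoint = "facebook/wav2vec2-base"
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = TFWav2Vec2ForSequenceClassification.from_pretrained(checkpoint, num_labels=4)

# Two fake mono waveforms of different lengths at 16 kHz; padding produces
# the attention_mask that drives the masked mean pooling above.
audio = [np.random.randn(16000).astype(np.float32),
         np.random.randn(12000).astype(np.float32)]
inputs = feature_extractor(audio, sampling_rate=16000, padding=True, return_tensors="tf")

outputs = model(**inputs)
predicted_ids = np.argmax(outputs.logits, axis=-1)
print(predicted_ids)  # one class id per clip
```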
Review comments:

- "Looks good to me - will leave the TF specificities to @Rocketknight1!"
- "This is correct, as far as I can see! `tensor_scatter_nd_update` is TF's equivalent to JAX's array assignment with the `.at[].set()` operation."
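For readers unfamiliar with that comparison, a small side-by-side sketch with toy shapes (the JAX half assumes jax is available):

```python
import tensorflow as tf
import jax.numpy as jnp

# TensorFlow: out-of-place scatter, as used in _get_feature_vector_attention_mask.
lengths = tf.constant([3, 5])
mask = tf.zeros((2, 5), dtype=tf.int32)
indices = tf.stack([tf.range(2), lengths - 1], axis=1)  # [[0, 2], [1, 4]]
mask = tf.tensor_scatter_nd_update(mask, indices, tf.ones([2], dtype=tf.int32))
print(mask.numpy())
# [[0 0 1 0 0]
#  [0 0 0 0 1]]

# JAX: the same update, written as functional index assignment.
jmask = jnp.zeros((2, 5), dtype=jnp.int32)
jmask = jmask.at[jnp.arange(2), jnp.array([3, 5]) - 1].set(1)
print(jmask)  # identical result
```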