diff --git a/docs/source/en/model_doc/wav2vec2.mdx b/docs/source/en/model_doc/wav2vec2.mdx
index 837e4526e858f5..166f6bb36c8ef3 100644
--- a/docs/source/en/model_doc/wav2vec2.mdx
+++ b/docs/source/en/model_doc/wav2vec2.mdx
@@ -197,6 +197,11 @@ Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower
 [[autodoc]] TFWav2Vec2Model
     - call
 
+## TFWav2Vec2ForSequenceClassification
+
+[[autodoc]] TFWav2Vec2ForSequenceClassification
+    - call
+
 ## TFWav2Vec2ForCTC
 
 [[autodoc]] TFWav2Vec2ForCTC
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 75ff4e8345a26d..bddc24719a01ee 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -3381,6 +3381,7 @@
         [
             "TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
             "TFWav2Vec2ForCTC",
+            "TFWav2Vec2ForSequenceClassification",
             "TFWav2Vec2Model",
             "TFWav2Vec2PreTrainedModel",
         ]
@@ -6509,6 +6510,7 @@
         from .models.wav2vec2 import (
             TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
             TFWav2Vec2ForCTC,
+            TFWav2Vec2ForSequenceClassification,
             TFWav2Vec2Model,
             TFWav2Vec2PreTrainedModel,
         )
diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py
index caf5ba71dc0318..d23f366e2182ea 100644
--- a/src/transformers/models/auto/modeling_tf_auto.py
+++ b/src/transformers/models/auto/modeling_tf_auto.py
@@ -348,6 +348,7 @@
         ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
     ]
 )
+TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])
 
 TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
     [
@@ -468,6 +469,9 @@
 TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
 )
+TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
+)
 
 
 class TFAutoModel(_BaseAutoModelClass):
@@ -477,6 +481,15 @@ class TFAutoModel(_BaseAutoModelClass):
 TFAutoModel = auto_class_update(TFAutoModel)
 
 
+class TFAutoModelForAudioClassification(_BaseAutoModelClass):
+    _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
+
+
+TFAutoModelForAudioClassification = auto_class_update(
+    TFAutoModelForAudioClassification, head_doc="audio classification"
+)
+
+
 class TFAutoModelForPreTraining(_BaseAutoModelClass):
     _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
 
diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py
index b55013cf54dd25..b3abdb99ec722d 100644
--- a/src/transformers/models/wav2vec2/__init__.py
+++ b/src/transformers/models/wav2vec2/__init__.py
@@ -59,6 +59,7 @@
         "TFWav2Vec2ForCTC",
         "TFWav2Vec2Model",
         "TFWav2Vec2PreTrainedModel",
+        "TFWav2Vec2ForSequenceClassification",
     ]
 
 try:
@@ -108,6 +109,7 @@
         from .modeling_tf_wav2vec2 import (
             TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
             TFWav2Vec2ForCTC,
+            TFWav2Vec2ForSequenceClassification,
             TFWav2Vec2Model,
             TFWav2Vec2PreTrainedModel,
         )
diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
index 64defa33597c66..dcc59d7f7322aa 100644
--- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
@@ -22,7 +22,7 @@
 import tensorflow as tf
 
 from ...activations_tf import get_tf_activation
-from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
+from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput
 from ...modeling_tf_utils import (
     TFPreTrainedModel,
     get_initializer,
@@ -1212,6 +1212,46 @@ def serving(self, inputs):
 
         return self.serving_output(output)
 
+    def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):
+        """
+        Computes the output length of the convolutional layers
+        """
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            return tf.math.floordiv(input_length - kernel_size, stride) + 1
+
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(
+        self, feature_vector_length: int, attention_mask: tf.Tensor, add_adapter=None
+    ):
+        non_padded_lengths = tf.math.cumsum(attention_mask, axis=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+        output_lengths = tf.cast(output_lengths, tf.int32)
+        batch_size = tf.shape(attention_mask)[0]
+        # check device here
+        attention_mask = tf.zeros(
+            (batch_size, feature_vector_length), dtype=attention_mask.dtype, name="attention_mask"
+        )  # these two operations make sure that all values before the output lengths idxs are attended to
+        ## check device
+        attention_mask = tf.tensor_scatter_nd_update(
+            attention_mask,
+            indices=tf.stack([tf.range(batch_size), output_lengths - 1], axis=1),
+            updates=tf.ones([batch_size], dtype=attention_mask.dtype),
+        )
+        attention_mask = tf.reverse(attention_mask, axis=[-1])
+        attention_mask = tf.cumsum(attention_mask, axis=-1)
+        attention_mask = tf.reverse(attention_mask, axis=[-1])
+        attention_mask = tf.cast(attention_mask, tf.bool)
+        return attention_mask
+
 
 WAV_2_VEC_2_START_DOCSTRING = r"""
 
@@ -1552,3 +1592,125 @@ def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
         hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
         attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
         return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
+
+
+class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
+        self.num_layers = config.num_hidden_layers + 1
+        with tf.name_scope(self._name_scope()):
+            if config.use_weighted_layer_sum:
+                self.layer_weights = self.add_weight(
+                    shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights"
+                )
+        self.config = config
+        self.projector = tf.keras.layers.Dense(units=config.classifier_proj_size, name="projector")
+        self.classifier = tf.keras.layers.Dense(units=config.num_labels, activation=None, name="classifier")
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2.feature_extractor.trainable = False + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for layer in self.wav2vec2.layers: + layer.trainable = False + + @unpack_inputs + def call( + self, + input_values: tf.Tensor, + attention_mask: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[tf.Tensor] = None, + training: bool = False, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = tf.stack(hidden_states, axis=1) + norm_weights = tf.nn.softmax(self.layer_weights, axis=-1) + hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = tf.reduce_mean(hidden_states, axis=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + padding_mask_float = tf.cast(padding_mask, hidden_states.dtype) + hidden_states = tf.multiply(hidden_states, tf.expand_dims(padding_mask_float, axis=-1)) + pooled_output = tf.divide( + tf.reduce_sum(hidden_states, axis=1), tf.expand_dims(tf.reduce_sum(padding_mask_float, axis=1), axis=1) + ) + logits = self.classifier(pooled_output) + loss = None + if labels is not None: + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels])) + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def serving_output(self, output): + hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFSequenceClassifierOutput( + logits=output.logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + @tf.function( + input_signature=[ + { + "input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"), + "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), + "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"), + } + ] + ) + def serving(self, inputs): + output = self.call(input_values=inputs) + + return 
self.serving_output(output) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 94d881ac75d3b9..9e84bb0dea0d76 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -2538,6 +2538,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFWav2Vec2ForSequenceClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFWav2Vec2Model(metaclass=DummyObject): _backends = ["tf"] diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index 3bb3d36cbfb211..8afcd66d8feeb8 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -47,7 +47,13 @@ if is_tf_available(): import tensorflow as tf - from transformers import TFWav2Vec2ForCTC, TFWav2Vec2Model, Wav2Vec2Processor + from transformers import ( + AutoFeatureExtractor, + TFWav2Vec2ForCTC, + TFWav2Vec2ForSequenceClassification, + TFWav2Vec2Model, + Wav2Vec2Processor, + ) from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices @@ -244,6 +250,29 @@ def check_ctc_loss(self, config, input_values, *args): self.parent.assertTrue(abs(labels.shape[0] * mean_loss - sum_loss) < 1e-2) + def check_seq_classifier_loss(self, loss, config, input_values, *args): + model = TFWav2Vec2ForSequenceClassification(config) + + input_values = input_values[:3] + attention_mask = tf.ones(input_values.shape, dtype=tf.int32) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + labels = tf.random.uniform((input_values.shape[0],), maxval=len(model.config.id2label), dtype=tf.int32) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0 + training = False + masked_loss = ( + model(input_values, attention_mask=attention_mask, labels=labels, training=training).loss.numpy().item() + ) + unmasked_loss = model(input_values, labels=labels, training=training).loss.numpy().item() + + assert isinstance(masked_loss, float) + assert isinstance(unmasked_loss, float) + assert masked_loss != unmasked_loss + def check_training(self, config, input_values, *args): model = TFWav2Vec2ForCTC(config) @@ -283,8 +312,14 @@ def prepare_config_and_inputs_for_common(self): @require_tf class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else () - pipeline_model_mapping = {"feature-extraction": TFWav2Vec2Model} if is_tf_available() else {} + all_model_classes = ( + (TFWav2Vec2Model, TFWav2Vec2ForCTC, TFWav2Vec2ForSequenceClassification) if is_tf_available() else () + ) + pipeline_model_mapping = ( + {"feature-extraction": TFWav2Vec2Model, "audio-classification": TFWav2Vec2ForSequenceClassification} + if is_tf_available() + else {} + ) test_resize_embeddings = False test_head_masking = False test_onnx = False @@ -400,7 +435,9 @@ def test_keras_fit(self): @require_tf class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase): - all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else () + all_model_classes = ( + (TFWav2Vec2Model, TFWav2Vec2ForCTC, TFWav2Vec2ForSequenceClassification) if is_tf_available() else () + ) test_resize_embeddings = False test_head_masking = False test_onnx = False 
@@ -564,6 +601,11 @@ def _load_datasamples(self, num_samples):
 
         return [x["array"] for x in speech_samples]
 
+    def _load_superb(self, task, num_samples):
+        ds = load_dataset("anton-l/superb_dummy", task, split="test")
+
+        return ds[:num_samples]
+
     def test_inference_ctc_normal(self):
         model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
         processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -676,3 +718,87 @@ def test_wav2vec2_with_lm_pool(self):
     @require_librosa
     def test_wav2vec2_with_lm_invalid_pool(self):
         run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
+
+    def test_inference_keyword_spotting(self):
+        model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks", from_pt=True)
+        processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
+        input_data = self._load_superb("ks", 4)
+        inputs = processor(input_data["speech"], return_tensors="tf", padding=True)
+        input_values = inputs.input_values
+        attention_mask = inputs.attention_mask
+        outputs = model(input_values, attention_mask)
+        predicted_logits, predicted_ids = tf.math.reduce_max(outputs.logits, axis=-1), tf.argmax(
+            outputs.logits, axis=-1
+        )
+        expected_labels = [7, 6, 10, 9]
+        expected_logits = tf.convert_to_tensor([6.1186, 11.8961, 10.2931, 6.0898])
+        self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
+        self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
+
+    def test_inference_intent_classification(self):
+        model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic", from_pt=True)
+        processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")
+        input_data = self._load_superb("ic", 4)
+        inputs = processor(input_data["speech"], return_tensors="tf", padding=True)
+        input_values = inputs.input_values
+        attention_mask = inputs.attention_mask
+        outputs = model(input_values, attention_mask=attention_mask)
+        predicted_logits_action, predicted_ids_action = tf.math.reduce_max(outputs.logits[:, :6], axis=-1), tf.argmax(
+            outputs.logits[:, :6], axis=-1
+        )
+        predicted_logits_object, predicted_ids_object = tf.math.reduce_max(
+            outputs.logits[:, 6:20], axis=-1
+        ), tf.argmax(outputs.logits[:, 6:20], axis=-1)
+        predicted_logits_location, predicted_ids_location = tf.math.reduce_max(
+            outputs.logits[:, 20:24], axis=-1
+        ), tf.argmax(outputs.logits[:, 20:24], axis=-1)
+        expected_labels_action = [0, 0, 2, 3]
+        expected_logits_action = tf.convert_to_tensor([0.4568, 11.0848, 1.6621, 9.3841])
+        expected_labels_object = [3, 10, 3, 4]
+        expected_logits_object = tf.convert_to_tensor([1.5322, 10.7094, 5.2469, 22.1318])
+        expected_labels_location = [0, 0, 0, 1]
+        expected_logits_location = tf.convert_to_tensor([1.5335, 6.5096, 10.5704, 11.0569])
+
+        self.assertListEqual(predicted_ids_action.numpy().tolist(), expected_labels_action)
+        self.assertListEqual(predicted_ids_object.numpy().tolist(), expected_labels_object)
+        self.assertListEqual(predicted_ids_location.numpy().tolist(), expected_labels_location)
+
+        self.assertTrue(np.allclose(predicted_logits_action, expected_logits_action, atol=1e-2))
+        self.assertTrue(np.allclose(predicted_logits_object, expected_logits_object, atol=1e-2))
+        self.assertTrue(np.allclose(predicted_logits_location, expected_logits_location, atol=1e-2))
+
+    def test_inference_speaker_identification(self):
+        model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid", from_pt=True)
+        processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-sid")
+        input_data = self._load_superb("si", 4)
+        output_logits = []
+        for example in input_data["speech"]:
+            input = processor(example, return_tensors="tf", padding=True)
+            output = model(input.input_values, attention_mask=None)
+            output_logits.append(output.logits[0])
+        output_logits = tf.stack(output_logits)
+        predicted_logits, predicted_ids = tf.math.reduce_max(output_logits, axis=-1), tf.argmax(output_logits, axis=-1)
+        expected_labels = [251, 1, 1, 3]
+        expected_logits = tf.convert_to_tensor([37.5627, 71.6362, 64.2419, 31.7778])
+        self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
+        self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
+
+    def test_inference_emotion_recognition(self):
+        model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er", from_pt=True)
+        processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
+        input_data = self._load_superb("er", 4)
+        inputs = processor(input_data["speech"], return_tensors="tf", padding=True)
+
+        input_values = inputs.input_values
+        attention_mask = inputs.attention_mask
+        outputs = model(input_values, attention_mask=attention_mask)
+        predicted_logits, predicted_ids = tf.math.reduce_max(outputs.logits, axis=-1), tf.argmax(
+            outputs.logits, axis=-1
+        )
+
+        expected_labels = [1, 1, 2, 2]
+        # s3prl logits for the same batch
+        expected_logits = tf.convert_to_tensor([2.1722, 3.0779, 8.0287, 6.6797])
+
+        self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
+        self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
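
Not part of the patch: a minimal usage sketch of the new head, assuming the `superb/wav2vec2-base-superb-ks` checkpoint exercised in the slow tests above, a 16 kHz mono waveform, and PyTorch installed so that `from_pt=True` can convert the weights on load. The `TFAutoModelForAudioClassification` class added in `modeling_tf_auto.py` resolves wav2vec2 checkpoints to the same head.

```python
import numpy as np
import tensorflow as tf

from transformers import AutoFeatureExtractor, TFWav2Vec2ForSequenceClassification

# Checkpoint borrowed from the slow tests above; from_pt=True converts the PyTorch weights on the fly.
model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks", from_pt=True)
feature_extractor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")

# One second of silence at 16 kHz stands in for a real utterance.
waveform = np.zeros(16000, dtype=np.float32)
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="tf", padding=True)

# The SUPERB extractors return an attention mask; forwarding it lets the head ignore padded frames when pooling.
outputs = model(inputs.input_values, attention_mask=inputs.attention_mask)
predicted_id = int(tf.argmax(outputs.logits, axis=-1)[0])
print(model.config.id2label[predicted_id])

# Equivalent via the new auto class:
# from transformers import TFAutoModelForAudioClassification
# model = TFAutoModelForAudioClassification.from_pretrained("superb/wav2vec2-base-superb-ks", from_pt=True)
```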