
Add TensorFlow Wav2Vec2 for sequence classification #22073

Merged

Conversation

nandwalritik (Contributor)

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

@sanchit-gandhi

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 10, 2023

The documentation is not available anymore as the PR was closed or merged.

@github-actions

github-actions bot commented Apr 9, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@sgugger (Collaborator)

sgugger commented Apr 10, 2023

Kindly pinging @sanchit-gandhi, and adding @Rocketknight1 for the TensorFlow side.

@Rocketknight1 (Member)

Hi @nandwalritik, and sorry for the extremely long delay in catching this! Ordinarily one of the TF maintainers reviews TF pull requests, but this one slipped through the cracks somehow. If you file TF PRs in the future, you can ping me or @gante directly to make sure we don't miss them.

This PR actually looks almost perfect, but there are a couple of TF-specific details that are causing some tests to fail. I'll mark them in a code review in just a sec, but they shouldn't take too long to fix. Thanks again for submitting this!

@Rocketknight1 (Member) left a comment

This looks good! A few tweaks in the __init__ should fix most of the issues.

The only other thing missing is a serving and serving_output method. These are methods unique to TF models; they define the input and output signatures that enable model compilation and exporting. You can look at other models in the library to get a sense of how they work, but if you can't figure it out, let me know and I'll add it for you!
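For reference, a rough sketch of what these can look like, following the pattern other TF models in the library use (the exact input signature for Wav2Vec2 is an assumption here; raw audio comes in as float input_values rather than integer input_ids):

import tensorflow as tf
from transformers.modeling_tf_outputs import TFSequenceClassifierOutput

@tf.function(
    input_signature=[
        {"input_values": tf.TensorSpec((None, None), tf.float32, name="input_values")}
    ]
)
def serving(self, inputs):
    # Entry point used when exporting the model as a SavedModel.
    output = self.call(input_values=inputs["input_values"])
    return self.serving_output(output)

def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
    # Convert tuples of per-layer tensors to stacked tensors so the exported
    # signature has a fixed structure.
    hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
    attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
    return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)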

(4 review comments on src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py, now outdated and resolved)
@nandwalritik force-pushed the add_wav2vec2_seq_classification branch from 94d74f8 to 8252b9f on April 12, 2023 04:46
@nandwalritik (Contributor, Author)

For the serving and serving_output methods I added changes, but I'm not sure whether they are correct.

@Rocketknight1 (Member)

Hi @nandwalritik, I can see the issue when you move it to build() - the problem is the weight name, as it usually is in our TensorFlow ports! TF isn't very consistent about the name scope used for weights: it can differ depending on whether the weight is created in the __init__, in build(), or lazily in call(). That makes things tricky, because we use the names to match weights between PT and TF models.

I'll see if I can push a solution to your repo, hang on.
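To see why the creation site matters, here is a standalone toy example (not the actual model code; the exact variable names may vary by TF version):

import tensorflow as tf

class Demo(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__(name="demo")
        # Created eagerly in __init__ with no explicit scope:
        # the variable typically gets no "demo/" prefix.
        self.w_plain = self.add_weight(shape=(2,), initializer="ones", name="w_plain")
        # Created in __init__ under an explicit scope: the prefix is restored,
        # which is what lets PT<->TF weight-name matching work.
        with tf.name_scope(self._name_scope()):
            self.w_scoped = self.add_weight(shape=(2,), initializer="ones", name="w_scoped")

    def build(self, input_shape):
        # Created lazily in build(): Keras applies the layer's name scope itself.
        self.w_build = self.add_weight(shape=(2,), initializer="ones", name="w_build")
        super().build(input_shape)

    def call(self, x):
        return x * self.w_plain + self.w_scoped + self.w_build

layer = Demo()
layer(tf.ones((1, 2)))
for v in layer.weights:
    print(v.name)  # e.g. w_plain:0 vs. demo/w_scoped:0 vs. demo/w_build:0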

@nandwalritik (Contributor, Author)

Ok

@Rocketknight1 (Member)

Try:

with tf.name_scope(self._name_scope()):
    self.layer_weights = self.add_weight(
        shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights"
    )

in the __init__, not the build(). I know that contradicts what I said earlier, but it turns out to be a bit different for a base model class than for a sublayer.

I also see a couple of other errors - you can see them by clicking the Details beside tests_tf in the checklist at the bottom of this PR. If you can't figure out what's causing them, ping me over the weekend or on Monday and I'll try to debug them!

@nandwalritik force-pushed the add_wav2vec2_seq_classification branch from 3da119b to 983965e on April 17, 2023 04:51
@nandwalritik (Contributor, Author)

> Try: with tf.name_scope(self._name_scope()): self.layer_weights = self.add_weight(...) in the __init__, not the build().

Ok, so after adding this change the weights load without any warning or error, but the outputs of the PyTorch and TensorFlow models don't match within an rtol of 1e-5. I did check the shapes and absolute sums of the tensors from both models, and they are almost equal:

PT model
hidden_state: 1,292,768 -> 29877.8750
1,292,256 -> 29711.7109
pooled_output: 1,256 -> 38.7491

TF model
hidden_state: 1,292,768 -> 29877.879
1,292,256 -> 29711.715
pooled_output: 1,256 -> 38.811996

What should I try next to satisfy the rtol criterion?
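(For reference, the check being applied is a relative-tolerance comparison; a minimal sketch using the pooled_output sums above:)

import numpy as np

pt_val = np.array([38.7491])     # PT pooled_output abs sum from above
tf_val = np.array([38.811996])   # TF pooled_output abs sum from above
# np.allclose passes when |tf - pt| <= atol + rtol * |pt| elementwise
print(np.allclose(tf_val, pt_val, rtol=1e-5))  # False: relative diff is ~1.6e-3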

@Rocketknight1
Copy link
Member

Rocketknight1 commented Apr 17, 2023

Hm, those are some fairly large discrepancies! The debugging process we recommend when something like this happens is (see the sketch after this list):

  • Make a test environment and load the PT and TF models with the same weights
  • Try to isolate the earliest point where the model outputs diverge. You can use options like output_hidden_states to get the model to return all hidden states, not just the final ones.
  • Once you find the first point of divergence, try to dig into the layer where the divergence happened. You can place breakpoints, or extract sublayers and pass test inputs into them.
  • Eventually you will find the single specific place where the divergence creeps in - then you can check what the cause is. Make sure the weights for that operation really do match between the two frameworks, and make sure both frameworks are doing the same thing at that point.
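For example, a minimal version of that comparison might look like this (the checkpoint name and tolerance here are placeholders, not the ones from this PR):

import numpy as np
import torch
from transformers import TFWav2Vec2Model, Wav2Vec2Model

ckpt = "facebook/wav2vec2-base"  # placeholder checkpoint
pt_model = Wav2Vec2Model.from_pretrained(ckpt)
tf_model = TFWav2Vec2Model.from_pretrained(ckpt, from_pt=True)

speech = np.random.randn(1, 16000).astype(np.float32)  # dummy 1-second waveform
with torch.no_grad():
    pt_out = pt_model(torch.from_numpy(speech), output_hidden_states=True)
tf_out = tf_model(speech, output_hidden_states=True)

# Walk the hidden states layer by layer and report the first divergence
for i, (pt_h, tf_h) in enumerate(zip(pt_out.hidden_states, tf_out.hidden_states)):
    max_diff = np.abs(pt_h.numpy() - tf_h.numpy()).max()
    print(f"hidden state {i}: max abs diff {max_diff:.2e}")
    if max_diff > 1e-5:
        print(f"first divergence at hidden state {i}")
        break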

As always, if you can't figure it out, let me know! This kind of work can be quite gruelling, but we really appreciate the work you're doing on the model port.

@nandwalritik (Contributor, Author)

Hi @Rocketknight1, I added test cases and fixed the feed-forward part, but the CI is failing due to Flax; I think this is not related to my changes. Please review the PR and let me know if any more changes are required.

@Rocketknight1 (Member)

Yep, those Flax issues are unrelated, just ignore them. I'll review everything today, but the CI looks good!

@Rocketknight1 (Member) left a comment

This looks really good! I left a couple of minor comments, but this is basically ready to go at this point. The inference tests give me very high confidence that this matches the behaviour of the PT/FLAX model up to numerical error. Thanks for all the effort you put in with this PR - it's really appreciated, and I think people will get a lot of use out of the result!

@sanchit-gandhi (Contributor) left a comment

Very nice PR @nandwalritik - it looks good from an audio perspective. Just wanted to confirm: do the slow tests pass? I'm pretty confident we have equality between the TF and PT code based on what I've seen, but would love to hear what @Rocketknight1 says here too!

Edit: he beat me to it!

    (batch_size, feature_vector_length), dtype=attention_mask.dtype, name="attention_mask"
)  # these two operations make sure that all values before the output length indices are attended to
## check device
attention_mask = tf.tensor_scatter_nd_update(
Contributor

Looks good to me - will leave the TF specificities to @Rocketknight1!

Member

This is correct, as far as I can see! tensor_scatter_nd_update is TF's equivalent to JAX's array assignment with the .at[].set() operation.
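As a toy illustration of that correspondence (the values here are made up):

import tensorflow as tf

x = tf.zeros((2, 4), dtype=tf.int32)
indices = tf.constant([[0, 1], [1, 3]])  # (row, col) positions to update
updates = tf.constant([5, 7])            # values to write at those positions
y = tf.tensor_scatter_nd_update(x, indices, updates)
# y -> [[0, 5, 0, 0],
#       [0, 0, 0, 7]]
# Roughly the JAX equivalent: jnp.zeros((2, 4), jnp.int32).at[(0, 1)].set(5), etc.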

@@ -1552,3 +1592,123 @@ def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)


class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel):
Contributor

Also looks good to me (matching PyTorch) - will leave the TF specificities again to @Rocketknight1

def test_inference_keyword_spotting(self):
    model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks", from_pt=True)
    processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
    input_data = self._load_superb("ks", 4)
Contributor

Super(b)! Excellent work on getting the slow tests to pass!

@nandwalritik (Contributor, Author)

@sanchit-gandhi @Rocketknight1 let me know if any more changes are required; otherwise, could you get this PR merged?

@Rocketknight1 (Member)

Just looked over the last few changes - I'm happy to merge it at this point. Thanks again for putting in the work on this!

@Rocketknight1 merged commit 20ac86c into huggingface:main on Apr 26, 2023
gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
* Add initial changes for TF wav2vec2 for sequence classification

* Add suggested changes

* Add serving and serving output methods

* Add serving_output implementation and fix layer_weights

* Add fixes

* Fixed test cases

* Fixing test and adding suggested changes
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
(same squashed commit message as above)