Add TensorFlow Wav2Vec2 for sequence classification #22073
The documentation is not available anymore as the PR was closed or merged.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Kindly pinging @sanchit-gandhi, and adding @Rocketknight1 for the TensorFlow side.
Hi @nandwalritik, and sorry for the extremely long delay in catching this! Ordinarily one of the TF maintainers reviews TF pull requests, but this one slipped through the cracks somehow. If you want to file TF PRs in future, you can directly ping me or @gante to make sure that we don't miss it. This PR actually looks almost perfect, but there are a couple of TF-specific details that are causing some tests to fail. I'll mark them in a code review in just a sec, but they shouldn't take too long to fix. Thanks again for submitting this!
This looks good! A few tweaks in the __init__ should fix most of the issues.
The only other things missing are the serving and serving_output methods. These are unique to TF models, and indicate the input and output signatures to enable model compilation and exporting. You can look at other models in the library to get a sense for how they work, but if you can't figure it out let me know and I'll add it for you!
94d74f8 to 8252b9f
Hi @nandwalritik, I'm seeing the issue when you move it to [...]. I'll see if I can push a solution to your repo, hang on.
Ok
Try: [...] in the [...]. I also see a couple of other errors - you can see them by clicking the [...]
3da119b to 983965e
OK, so after adding this change, the weights load without any warning or error, but the outputs of the PyTorch and TensorFlow models don't have [...]. What should I try next to satisfy the rtol criteria?
Hm, those are some fairly large discrepancies! The debugging process we recommend when something like that happens is:
As always, if you can't figure it out, let me know! This kind of work can be quite gruelling, but we really appreciate the work you're doing on the model port.
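As a generic illustration of chasing such a discrepancy (not necessarily the process the maintainers recommend above), one can collect intermediate activations from both models in forward order and report the first one that fails an rtol/atol check. This is a minimal sketch: first_divergence and the layer names in the usage example are hypothetical, and the arrays stand in for real PyTorch and TensorFlow outputs converted to NumPy.

```python
import numpy as np


def first_divergence(pt_activations, tf_activations, rtol=1e-5, atol=1e-5):
    """Return (layer_name, max_abs_diff) for the first mismatching layer, else None.

    Both arguments are dicts mapping a layer name to a NumPy array, inserted
    in forward-pass order (hypothetical names; dicts preserve insertion order).
    """
    for name, pt_value in pt_activations.items():
        pt_value = np.asarray(pt_value)
        tf_value = np.asarray(tf_activations[name])
        if not np.allclose(pt_value, tf_value, rtol=rtol, atol=atol):
            return name, float(np.max(np.abs(pt_value - tf_value)))
    return None


# Usage with made-up activations: the second "layer" is off by 0.5.
pt = {"feature_extractor": np.ones(3), "projector": np.ones(3)}
tf_out = {"feature_extractor": np.ones(3), "projector": np.ones(3) + 0.5}
print(first_divergence(pt, tf_out))
```

The first layer reported is usually the most useful one to inspect, since every later mismatch may simply be inherited from it.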
Hi @Rocketknight1, I added test cases and fixed the feed forward part, but the CI is failing due to [...]
Yep, those flax issues are unrelated, just ignore them. I'll review everything today, but the CI looks good!
This looks really good! I left a couple of minor comments, but this is basically ready to go at this point. The inference tests give me very high confidence that this matches the behaviour of the PT/FLAX model up to numerical error. Thanks for all the effort you put in with this PR - it's really appreciated, and I think people will get a lot of use out of the result!
Very nice PR @nandwalritik - looks good from an audio perspective. Just wanted to confirm that the slow tests pass? I'm pretty confident we have equality between the TF and PT code based on what I've seen, but would love to hear what @Rocketknight1 says here too!
Edit: he beat me to it!
(batch_size, feature_vector_length), dtype=attention_mask.dtype, name="attention_mask"
)  # these two operations makes sure that all values before the output lengths idxs are attended to
## check device
attention_mask = tf.tensor_scatter_nd_update(
Looks good to me - will leave the TF specificities to @Rocketknight1!
This is correct, as far as I can see! tensor_scatter_nd_update is TF's equivalent to JAX's array assignment with the .at[].set() operation.
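For readers unfamiliar with the pattern under review, here is a minimal sketch of building a padding mask from per-example output lengths with tf.tensor_scatter_nd_update. The helper name lengths_to_attention_mask and the reverse/cumsum/reverse steps after the scatter are assumptions for illustration (the diff fragment above is truncated), so treat this as one plausible completion rather than the exact code in the PR.

```python
import tensorflow as tf


def lengths_to_attention_mask(output_lengths, feature_vector_length):
    """Hypothetical helper: True for positions before each example's length."""
    batch_size = len(output_lengths)
    output_lengths = tf.constant(output_lengths, dtype=tf.int32)
    mask = tf.zeros((batch_size, feature_vector_length), dtype=tf.int32)

    # Put a single 1 at the last valid index of every row. TF tensors are
    # immutable, so the scatter returns a new tensor; in JAX the analogous
    # step would use the .at[...].set(1) syntax mentioned in the comment above.
    indices = tf.stack([tf.range(batch_size), output_lengths - 1], axis=1)
    mask = tf.tensor_scatter_nd_update(mask, indices, tf.ones([batch_size], dtype=tf.int32))

    # Reverse, cumulative-sum, reverse: spreads that 1 to all earlier positions.
    mask = tf.reverse(mask, axis=[-1])
    mask = tf.cumsum(mask, axis=-1)
    mask = tf.reverse(mask, axis=[-1])
    return tf.cast(mask, tf.bool)


print(lengths_to_attention_mask([2, 4], 5).numpy())
```

With lengths 2 and 4 and a feature length of 5, the first row attends to its first two positions and the second row to its first four.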
@@ -1552,3 +1592,123 @@ def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
        hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
        attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
        return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)


class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel):
Also looks good to me (matching PyTorch) - will leave the TF specificities again to @Rocketknight1
def test_inference_keyword_spotting(self):
    model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks", from_pt=True)
    processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
    input_data = self._load_superb("ks", 4)
Super(b)! Excellent work on getting the slow tests to pass!
@sanchit-gandhi @Rocketknight1 let me know if any more changes are required; otherwise, could you get this PR merged?
Just looked over the last few changes - I'm happy to merge it at this point. Thanks again for putting in the work on this!
* Add initial changes for TF wav2vec2 for sequence classification
* Add suggested changes
* Add serving and serving output methods
* Add serving_output implementation and fix layer_weights
* Add fixes
* Fixed test cases
* Fixing test and adding suggested changes
What does this PR do?
Fixes # (issue)
Before submitting
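* Did you read the contributor guideline, Pull Request section?
* Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. Add TensorFlow Wav2Vec2 for sequence classification #21778
* Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.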
Who can review?
@sanchit-gandhi
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.