diff --git a/README.md b/README.md
index daab3d1f9d6bbe..b1ad351337671b 100644
--- a/README.md
+++ b/README.md
@@ -518,6 +518,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
+1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
diff --git a/README_es.md b/README_es.md
index 9e1ac93b4a99ab..88eed018b8dc5e 100644
--- a/README_es.md
+++ b/README_es.md
@@ -493,6 +493,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
+1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
diff --git a/README_hd.md b/README_hd.md
index 92935efb589cee..8eb69b13df118a 100644
--- a/README_hd.md
+++ b/README_hd.md
@@ -467,6 +467,7 @@ conda install conda-forge::transformers
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (Kakao Enterprise से) Jaehyeon Kim, Jungil Kong, Juhee Son. द्वाराअनुसंधान पत्र [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) के साथ जारी किया गया
1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (फेसबुक एआई से) साथ में पेपर [wav2vec 2.0: ए फ्रेमवर्क फॉर सेल्फ-सुपरवाइज्ड लर्निंग ऑफ स्पीच रिप्रेजेंटेशन](https://arxiv.org/abs/2006.11477) एलेक्सी बेवस्की, हेनरी झोउ, अब्देलरहमान मोहम्मद, माइकल औली द्वारा।
+1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (Facebook AI से) साथ वाला पेपर [FAIRSEQ S2T: FAIRSEQ के साथ फास्ट स्पीच-टू-टेक्स्ट मॉडलिंग ](https://arxiv.org/abs/2010.05171) चांगहान वांग, यूं तांग, जुताई मा, ऐनी वू, सरव्या पोपुरी, दिमित्रो ओखोनको, जुआन पिनो द्वारा पोस्ट किया गया।
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (Facebook AI से) साथ वाला पेपर [सरल और प्रभावी जीरो-शॉट क्रॉस-लिंगुअल फोनेम रिकॉग्निशन](https://arxiv.org/abs/2109.11680) कियानटोंग जू, एलेक्सी बाएव्स्की, माइकल औली द्वारा।
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (माइक्रोसॉफ्ट रिसर्च से) पेपर के साथ जारी किया गया [WavLM: फुल स्टैक के लिए बड़े पैमाने पर स्व-पर्यवेक्षित पूर्व-प्रशिक्षण स्पीच प्रोसेसिंग](https://arxiv.org/abs/2110.13900) सानयुआन चेन, चेंगयी वांग, झेंगयांग चेन, यू वू, शुजी लियू, ज़ुओ चेन, जिन्यु ली, नाओयुकी कांडा, ताकुया योशियोका, ज़िओंग जिओ, जियान वू, लॉन्ग झोउ, शुओ रेन, यानमिन कियान, याओ कियान, जियान वू, माइकल ज़ेंग, फुरु वेई।
diff --git a/README_ja.md b/README_ja.md
index f43dda021c6f19..23fa0c2d5718ee 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -527,6 +527,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (Kakao Enterprise から) Jaehyeon Kim, Jungil Kong, Juhee Son. から公開された研究論文 [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103)
1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (Facebook AI から) Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli から公開された研究論文: [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477)
+1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (Facebook AI から) Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino から公開された研究論文: [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171)
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (Facebook AI から) Qiantong Xu, Alexei Baevski, Michael Auli から公開された研究論文: [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680)
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (Microsoft Research から) Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei から公開された研究論文: [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900)
diff --git a/README_ko.md b/README_ko.md
index c2e53a1b81ce95..a1fb66492dad99 100644
--- a/README_ko.md
+++ b/README_ko.md
@@ -442,6 +442,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (Kakao Enterprise 에서 제공)은 Jaehyeon Kim, Jungil Kong, Juhee Son.의 [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103)논문과 함께 발표했습니다.
1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (Facebook AI 에서) Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli 의 [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) 논문과 함께 발표했습니다.
+1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (Facebook AI 에서) Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino 의 [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) 논문과 함께 발표했습니다.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (Facebook AI 에서) Qiantong Xu, Alexei Baevski, Michael Auli 의 [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) 논문과 함께 발표했습니다.
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (Microsoft Research 에서) Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei 의 [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) 논문과 함께 발표했습니다.
diff --git a/README_zh-hans.md b/README_zh-hans.md
index 972f3a386f420e..b14f8f61050d2b 100644
--- a/README_zh-hans.md
+++ b/README_zh-hans.md
@@ -466,6 +466,7 @@ conda install conda-forge::transformers
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (来自 Kakao Enterprise) 伴随论文 [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) 由 Jaehyeon Kim, Jungil Kong, Juhee Son 发布。
1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (来自 Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) 由 Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (来自 Facebook AI) 伴随论文 [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) 由 Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli 发布。
+1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (来自 Facebook AI) 伴随论文 [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) 由 Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino 发布。
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (来自 Facebook AI) 伴随论文 [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) 由 Qiantong Xu, Alexei Baevski, Michael Auli 发布。
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
diff --git a/README_zh-hant.md b/README_zh-hant.md
index b17c8946bc3e30..c03013b8baaaa5 100644
--- a/README_zh-hant.md
+++ b/README_zh-hant.md
@@ -478,6 +478,7 @@ conda install conda-forge::transformers
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
+1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 86cffb9a7e35cf..a1e0965121d7ef 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -648,6 +648,8 @@
title: VITS
- local: model_doc/wav2vec2
title: Wav2Vec2
+ - local: model_doc/wav2vec2-bert
+ title: Wav2Vec2-BERT
- local: model_doc/wav2vec2-conformer
title: Wav2Vec2-Conformer
- local: model_doc/wav2vec2_phoneme
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 52b5df6e59ba14..421431fedb0000 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -295,6 +295,7 @@ Flax), PyTorch, and/or TensorFlow.
| [VITS](model_doc/vits) | ✅ | ❌ | ❌ |
| [ViViT](model_doc/vivit) | ✅ | ❌ | ❌ |
| [Wav2Vec2](model_doc/wav2vec2) | ✅ | ✅ | ✅ |
+| [Wav2Vec2-BERT](model_doc/wav2vec2-bert) | ✅ | ❌ | ❌ |
| [Wav2Vec2-Conformer](model_doc/wav2vec2-conformer) | ✅ | ❌ | ❌ |
| [Wav2Vec2Phoneme](model_doc/wav2vec2_phoneme) | ✅ | ✅ | ✅ |
| [WavLM](model_doc/wavlm) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/wav2vec2-bert.md b/docs/source/en/model_doc/wav2vec2-bert.md
new file mode 100644
index 00000000000000..6514133330a9d4
--- /dev/null
+++ b/docs/source/en/model_doc/wav2vec2-bert.md
@@ -0,0 +1,90 @@
+
+
+# Wav2Vec2-BERT
+
+## Overview
+
+The Wav2Vec2-BERT model was proposed in [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team from Meta AI.
+
+This model was pre-trained on 4.5M hours of unlabeled audio data covering more than 143 languages. It requires finetuning to be used for downstream tasks such as Automatic Speech Recognition (ASR), or Audio Classification.
+
+The official results of the model can be found in Section 3.2.1 of the paper.
+
+The abstract from the paper is the following:
+
+*Recent advancements in automatic speech translation have dramatically expanded language coverage, improved multimodal capabilities, and enabled a wide range of tasks and functionalities. That said, large-scale automatic speech translation systems today lack key features that help machine-mediated communication feel seamless when compared to human-to-human dialogue. In this work, we introduce a family of models that enable end-to-end expressive and multilingual translations in a streaming fashion. First, we contribute an improved version of the massively multilingual and multimodal SeamlessM4T model—SeamlessM4T v2. This newer model, incorporating an updated UnitY2 framework, was trained on more low-resource language data. The expanded version of SeamlessAlign adds 114,800 hours of automatically aligned data for a total of 76 languages. SeamlessM4T v2 provides the foundation on which our two newest models, SeamlessExpressive and SeamlessStreaming, are initiated. SeamlessExpressive enables translation that preserves vocal styles and prosody. Compared to previous efforts in expressive speech research, our work addresses certain underexplored aspects of prosody, such as speech rate and pauses, while also preserving the style of one’s voice. As for SeamlessStreaming, our model leverages the Efficient Monotonic Multihead Attention (EMMA) mechanism to generate low-latency target translations without waiting for complete source utterances. As the first of its kind, SeamlessStreaming enables simultaneous speech-to-speech/text translation for multiple source and target languages. To understand the performance of these models, we combined novel and modified versions of existing automatic metrics to evaluate prosody, latency, and robustness. For human evaluations, we adapted existing protocols tailored for measuring the most relevant attributes in the preservation of meaning, naturalness, and expressivity. To ensure that our models can be used safely and responsibly, we implemented the first known red-teaming effort for multimodal machine translation, a system for the detection and mitigation of added toxicity, a systematic evaluation of gender bias, and an inaudible localized watermarking mechanism designed to dampen the impact of deepfakes. Consequently, we bring major components from SeamlessExpressive and SeamlessStreaming together to form Seamless, the first publicly available system that unlocks expressive cross-lingual communication in real-time. In sum, Seamless gives us a pivotal look at the technical foundation needed to turn the Universal Speech Translator from a science fiction concept into a real-world technology. Finally, contributions in this work—including models, code, and a watermark detector—are publicly released and accessible at the link below.*
+
+This model was contributed by [ylacombe](https://huggingface.co/ylacombe). The original code can be found [here](https://github.com/facebookresearch/seamless_communication).
+
+## Usage tips
+
+- Wav2Vec2-BERT follows the same architecture as Wav2Vec2-Conformer, but employs a causal depthwise convolutional layer and uses as input a mel-spectrogram representation of the audio instead of the raw waveform.
+- Wav2Vec2-BERT can use either no relative position embeddings, Shaw-like position embeddings, Transformer-XL-like position embeddings, or
+ rotary position embeddings by setting the correct `config.position_embeddings_type`.
+- Wav2Vec2-BERT also introduces a Conformer-based adapter network instead of a simple convolutional network.
+
+## Resources
+
+
+
+- [`Wav2Vec2BertForCTC`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition).
+- You can also adapt these notebooks on [how to finetune a speech recognition model in English](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/speech_recognition.ipynb), and [how to finetune a speech recognition model in any language](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multi_lingual_speech_recognition.ipynb).
+
+
+
+- [`Wav2Vec2BertForSequenceClassification`] can be used by adapting this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/audio-classification).
+- See also: [Audio classification task guide](../tasks/audio_classification)
+
+
+## Wav2Vec2BertConfig
+
+[[autodoc]] Wav2Vec2BertConfig
+
+## Wav2Vec2BertProcessor
+
+[[autodoc]] Wav2Vec2BertProcessor
+ - __call__
+ - pad
+ - from_pretrained
+ - save_pretrained
+ - batch_decode
+ - decode
+
+## Wav2Vec2BertModel
+
+[[autodoc]] Wav2Vec2BertModel
+ - forward
+
+## Wav2Vec2BertForCTC
+
+[[autodoc]] Wav2Vec2BertForCTC
+ - forward
+
+## Wav2Vec2BertForSequenceClassification
+
+[[autodoc]] Wav2Vec2BertForSequenceClassification
+ - forward
+
+## Wav2Vec2BertForAudioFrameClassification
+
+[[autodoc]] Wav2Vec2BertForAudioFrameClassification
+ - forward
+
+## Wav2Vec2BertForXVector
+
+[[autodoc]] Wav2Vec2BertForXVector
+ - forward
diff --git a/docs/source/en/tasks/asr.md b/docs/source/en/tasks/asr.md
index d01269ba60a696..737460ed297bcf 100644
--- a/docs/source/en/tasks/asr.md
+++ b/docs/source/en/tasks/asr.md
@@ -32,7 +32,7 @@ The task illustrated in this tutorial is supported by the following model archit
-[Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [M-CTC-T](../model_doc/mctct), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm)
+[Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [M-CTC-T](../model_doc/mctct), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-BERT](../model_doc/wav2vec2-bert), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm)
diff --git a/docs/source/en/tasks/audio_classification.md b/docs/source/en/tasks/audio_classification.md
index 743a797fc53fa8..678af90c4fa079 100644
--- a/docs/source/en/tasks/audio_classification.md
+++ b/docs/source/en/tasks/audio_classification.md
@@ -32,7 +32,7 @@ The task illustrated in this tutorial is supported by the following model archit
-[Audio Spectrogram Transformer](../model_doc/audio-spectrogram-transformer), [Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm), [Whisper](../model_doc/whisper)
+[Audio Spectrogram Transformer](../model_doc/audio-spectrogram-transformer), [Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-BERT](../model_doc/wav2vec2-bert), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm), [Whisper](../model_doc/whisper)
diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
index 1c658904e71e30..3ca9a2c6f44d3e 100755
--- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
+++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
@@ -132,6 +132,13 @@ class ModelArguments:
ctc_loss_reduction: Optional[str] = field(
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
)
+ add_adapter: Optional[bool] = field(
+ default=False,
+ metadata={
+ "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2BERT Encoder. Can be very"
+ "useful to downsample the output length."
+ },
+ )
@dataclass
@@ -602,6 +609,7 @@ def remove_special_characters(batch):
"pad_token_id": tokenizer.pad_token_id,
"vocab_size": len(tokenizer),
"activation_dropout": model_args.activation_dropout,
+ "add_adapter": model_args.add_adapter,
}
)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 4941d724455dfe..112d84c7099fb2 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -909,6 +909,11 @@
"Wav2Vec2Processor",
"Wav2Vec2Tokenizer",
],
+ "models.wav2vec2_bert": [
+ "WAV2VEC2_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "Wav2Vec2BertConfig",
+ "Wav2Vec2BertProcessor",
+ ],
"models.wav2vec2_conformer": [
"WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Wav2Vec2ConformerConfig",
@@ -3501,6 +3506,17 @@
"Wav2Vec2PreTrainedModel",
]
)
+ _import_structure["models.wav2vec2_bert"].extend(
+ [
+ "WAV2VEC2_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "Wav2Vec2BertForAudioFrameClassification",
+ "Wav2Vec2BertForCTC",
+ "Wav2Vec2BertForSequenceClassification",
+ "Wav2Vec2BertForXVector",
+ "Wav2Vec2BertModel",
+ "Wav2Vec2BertPreTrainedModel",
+ ]
+ )
_import_structure["models.wav2vec2_conformer"].extend(
[
"WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -5602,6 +5618,11 @@
Wav2Vec2Processor,
Wav2Vec2Tokenizer,
)
+ from .models.wav2vec2_bert import (
+ WAV2VEC2_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ Wav2Vec2BertConfig,
+ Wav2Vec2BertProcessor,
+ )
from .models.wav2vec2_conformer import (
WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
Wav2Vec2ConformerConfig,
@@ -7799,6 +7820,15 @@
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
+ from .models.wav2vec2_bert import (
+ WAV2VEC2_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ Wav2Vec2BertForAudioFrameClassification,
+ Wav2Vec2BertForCTC,
+ Wav2Vec2BertForSequenceClassification,
+ Wav2Vec2BertForXVector,
+ Wav2Vec2BertModel,
+ Wav2Vec2BertPreTrainedModel,
+ )
from .models.wav2vec2_conformer import (
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ConformerForAudioFrameClassification,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 2c20873c2ed79d..cf7965f6e8cf3c 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -234,6 +234,7 @@
vits,
vivit,
wav2vec2,
+ wav2vec2_bert,
wav2vec2_conformer,
wav2vec2_phoneme,
wav2vec2_with_lm,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 9eb3f1985c8536..b6b1ab3c3ee60d 100755
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -245,6 +245,7 @@
("vits", "VitsConfig"),
("vivit", "VivitConfig"),
("wav2vec2", "Wav2Vec2Config"),
+ ("wav2vec2-bert", "Wav2Vec2BertConfig"),
("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
("wavlm", "WavLMConfig"),
("whisper", "WhisperConfig"),
@@ -457,6 +458,7 @@
("vits", "VITS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vivit", "VIVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+ ("wav2vec2-bert", "WAV2VEC2_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2-conformer", "WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("whisper", "WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xclip", "XCLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -715,6 +717,7 @@
("vits", "VITS"),
("vivit", "ViViT"),
("wav2vec2", "Wav2Vec2"),
+ ("wav2vec2-bert", "Wav2Vec2-BERT"),
("wav2vec2-conformer", "Wav2Vec2-Conformer"),
("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
("wavlm", "WavLM"),
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 457217566e7cfa..b3461e8b56a7a9 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -100,6 +100,7 @@
("vit_mae", "ViTFeatureExtractor"),
("vit_msn", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
+ ("wav2vec2-bert", "Wav2Vec2FeatureExtractor"),
("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"),
("wavlm", "Wav2Vec2FeatureExtractor"),
("whisper", "WhisperFeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 7bf50a4518fa88..60f71a2f5abb08 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -231,6 +231,7 @@
("vits", "VitsModel"),
("vivit", "VivitModel"),
("wav2vec2", "Wav2Vec2Model"),
+ ("wav2vec2-bert", "Wav2Vec2BertModel"),
("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
("wavlm", "WavLMModel"),
("whisper", "WhisperModel"),
@@ -1031,6 +1032,7 @@
("unispeech", "UniSpeechForSequenceClassification"),
("unispeech-sat", "UniSpeechSatForSequenceClassification"),
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
+ ("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"),
("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
("wavlm", "WavLMForSequenceClassification"),
("whisper", "WhisperForAudioClassification"),
@@ -1048,6 +1050,7 @@
("unispeech", "UniSpeechForCTC"),
("unispeech-sat", "UniSpeechSatForCTC"),
("wav2vec2", "Wav2Vec2ForCTC"),
+ ("wav2vec2-bert", "Wav2Vec2BertForCTC"),
("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
("wavlm", "WavLMForCTC"),
]
@@ -1059,6 +1062,7 @@
("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
+ ("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"),
("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
("wavlm", "WavLMForAudioFrameClassification"),
]
@@ -1070,6 +1074,7 @@
("data2vec-audio", "Data2VecAudioForXVector"),
("unispeech-sat", "UniSpeechSatForXVector"),
("wav2vec2", "Wav2Vec2ForXVector"),
+ ("wav2vec2-bert", "Wav2Vec2BertForXVector"),
("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
("wavlm", "WavLMForXVector"),
]
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index eee8af931e99ac..59dbcb53f5187d 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -90,6 +90,7 @@
("vipllava", "LlavaProcessor"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
+ ("wav2vec2-bert", "Wav2Vec2Processor"),
("wav2vec2-conformer", "Wav2Vec2Processor"),
("wavlm", "Wav2Vec2Processor"),
("whisper", "WhisperProcessor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index ac09eecd1e0e99..05357a3e7175a7 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -411,6 +411,7 @@
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("vits", ("VitsTokenizer", None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
+ ("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)),
diff --git a/src/transformers/models/wav2vec2_bert/__init__.py b/src/transformers/models/wav2vec2_bert/__init__.py
new file mode 100644
index 00000000000000..594f108bcaad96
--- /dev/null
+++ b/src/transformers/models/wav2vec2_bert/__init__.py
@@ -0,0 +1,70 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_wav2vec2_bert": [
+ "WAV2VEC2_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "Wav2Vec2BertConfig",
+ ],
+ "processing_wav2vec2_bert": ["Wav2Vec2BertProcessor"],
+}
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_wav2vec2_bert"] = [
+ "WAV2VEC2_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "Wav2Vec2BertForAudioFrameClassification",
+ "Wav2Vec2BertForCTC",
+ "Wav2Vec2BertForSequenceClassification",
+ "Wav2Vec2BertForXVector",
+ "Wav2Vec2BertModel",
+ "Wav2Vec2BertPreTrainedModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_wav2vec2_bert import (
+ WAV2VEC2_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ Wav2Vec2BertConfig,
+ )
+ from .processing_wav2vec2_bert import Wav2Vec2BertProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_wav2vec2_bert import (
+ WAV2VEC2_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ Wav2Vec2BertForAudioFrameClassification,
+ Wav2Vec2BertForCTC,
+ Wav2Vec2BertForSequenceClassification,
+ Wav2Vec2BertForXVector,
+ Wav2Vec2BertModel,
+ Wav2Vec2BertPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py
new file mode 100644
index 00000000000000..12593107ef939d
--- /dev/null
+++ b/src/transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py
@@ -0,0 +1,314 @@
+# coding=utf-8
+# Copyright 2024 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Wav2Vec2Bert model configuration"""
+
+import functools
+import operator
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+WAV2VEC2_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "facebook/w2v-bert-2.0": "https://huggingface.co/facebook/w2v-bert-2.0/resolve/main/config.json",
+}
+
+
+class Wav2Vec2BertConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Wav2Vec2BertModel`]. It is used to
+ instantiate an Wav2Vec2Bert model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2Bert
+ [facebook/wav2vec2-bert-rel-pos-large](https://huggingface.co/facebook/wav2vec2-bert-rel-pos-large)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*):
+ Vocabulary size of the Wav2Vec2Bert model. Defines the number of different tokens that can be
+ represented by the `inputs_ids` passed when calling [`Wav2Vec2BertModel`]. Vocabulary size of the
+ model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
+ method of [`Wav2Vec2BertModel`].
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ feature_projection_input_dim (`int`, *optional*, defaults to 160):
+ Input dimension of this model, i.e the dimension after processing input audios with [`SeamlessM4TFeatureExtractor`] or [`Wav2Vec2BertProcessor`].
+ hidden_act (`str` or `function`, *optional*, defaults to `"swish"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probabilitiy for the feature projection.
+ final_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for the final projection layer of [`Wav2Vec2BertForCTC`].
+ layerdrop (`float`, *optional*, defaults to 0.1):
+ The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more
+ details.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+ Recognition](https://arxiv.org/abs/1904.08779).
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+ procecure generates `mask_time_prob*len(time_axis)/mask_time_length ``independent masks over the axis. If
+ reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+ mask_time_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the time axis.
+ mask_time_min_masks (`int`, *optional*, defaults to 2):
+ The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
+ irrespectively of `mask_feature_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
+ mask_time_min_masks`.
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+ masking procecure generates `mask_feature_prob*len(feature_axis)/mask_time_length` independent masks over
+ the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
+ True`.
+ mask_feature_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the feature axis.
+ mask_feature_min_masks (`int`, *optional*, defaults to 0):
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+ step, irrespectively of `mask_feature_prob`. Only relevant if
+ `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`Wav2Vec2BertForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`Wav2Vec2BertForCTC`].
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+ instance of [`Wav2Vec2BertForSequenceClassification`].
+ classifier_proj_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the projection before token mean-pooling for classification.
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
+ module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
+ *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
+ *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
+ xvector_output_dim (`int`, *optional*, defaults to 512):
+ Dimensionality of the *XVector* embedding vectors.
+ pad_token_id (`int`, *optional*, defaults to 0): The id of the _beginning-of-stream_ token.
+ bos_token_id (`int`, *optional*, defaults to 1): The id of the _padding_ token.
+ eos_token_id (`int`, *optional*, defaults to 2): The id of the _end-of-stream_ token.
+ add_adapter (`bool`, *optional*, defaults to `False`):
+ Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very
+ useful for warm-starting Wav2Vec2Bert for SpeechEncoderDecoder models.
+ adapter_kernel_size (`int`, *optional*, defaults to 3):
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+ adapter_stride (`int`, *optional*, defaults to 2):
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+ num_adapter_layers (`int`, *optional*, defaults to 1):
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
+ True`.
+ adapter_act (`str` or `function`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the adapter layers. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
+ use_intermediate_ffn_before_adapter (`bool`, *optional*, defaults to `False`):
+ Whether an intermediate feed-forward block should be stacked on top of the Wav2Vec2Bert Encoder and before the adapter network.
+ Only relevant if `add_adapter is True`.
+ output_hidden_size (`int`, *optional*):
+ Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
+ if `add_adapter is True`.
+ position_embeddings_type (`str`, *optional*, defaults to `"relative_key"`):
+ Can be specified to :
+ - `rotary`, for rotary position embeddings.
+ - `relative`, for relative position embeddings.
+ - `relative_key`, for relative position embeddings as defined by Shaw in [Self-Attention
+ with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+ If left to `None`, no relative position embeddings is applied.
+ rotary_embedding_base (`int`, *optional*, defaults to 10000):
+ If `"rotary"` position embeddings are used, defines the size of the embedding base.
+ max_source_positions (`int`, *optional*, defaults to 5000):
+ if `"relative"` position embeddings are used, defines the maximum source input positions.
+ left_max_position_embeddings (`int`, *optional*, defaults to 64):
+ If `"relative_key"` (aka Shaw) position embeddings are used, defines the left clipping value for relative positions.
+ right_max_position_embeddings (`int`, *optional*, defaults to 8):
+ If `"relative_key"` (aka Shaw) position embeddings are used, defines the right clipping value for relative positions.
+ conv_depthwise_kernel_size (`int`, *optional*, defaults to 31):
+ Kernel size of convolutional depthwise 1D layer in Conformer blocks.
+ conformer_conv_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all convolutional layers in Conformer blocks.
+ Example:
+
+ ```python
+ >>> from transformers import Wav2Vec2BertConfig, Wav2Vec2BertModel
+
+ >>> # Initializing a Wav2Vec2Bert facebook/wav2vec2-bert-rel-pos-large style configuration
+ >>> configuration = Wav2Vec2BertConfig()
+
+ >>> # Initializing a model (with random weights) from the facebook/wav2vec2-bert-rel-pos-large style configuration
+ >>> model = Wav2Vec2BertModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "wav2vec2-bert"
+
+ def __init__(
+ self,
+ vocab_size=None,
+ hidden_size=1024,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ intermediate_size=4096,
+ feature_projection_input_dim=160,
+ hidden_act="swish",
+ hidden_dropout=0.0,
+ activation_dropout=0.0,
+ attention_dropout=0.0,
+ feat_proj_dropout=0.0,
+ final_dropout=0.1,
+ layerdrop=0.1,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ apply_spec_augment=True,
+ mask_time_prob=0.05,
+ mask_time_length=10,
+ mask_time_min_masks=2,
+ mask_feature_prob=0.0,
+ mask_feature_length=10,
+ mask_feature_min_masks=0,
+ ctc_loss_reduction="sum",
+ ctc_zero_infinity=False,
+ use_weighted_layer_sum=False,
+ classifier_proj_size=768,
+ tdnn_dim=(512, 512, 512, 512, 1500),
+ tdnn_kernel=(5, 3, 3, 1, 1),
+ tdnn_dilation=(1, 2, 3, 1, 1),
+ xvector_output_dim=512,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ add_adapter=False,
+ adapter_kernel_size=3,
+ adapter_stride=2,
+ num_adapter_layers=1,
+ adapter_act="relu",
+ use_intermediate_ffn_before_adapter=False,
+ output_hidden_size=None,
+ position_embeddings_type="relative_key",
+ rotary_embedding_base=10000,
+ max_source_positions=5000,
+ left_max_position_embeddings=64,
+ right_max_position_embeddings=8,
+ conv_depthwise_kernel_size=31,
+ conformer_conv_dropout=0.1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.num_attention_heads = num_attention_heads
+ self.feature_projection_input_dim = feature_projection_input_dim
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.feat_proj_dropout = feat_proj_dropout
+ self.final_dropout = final_dropout
+ self.layerdrop = layerdrop
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.vocab_size = vocab_size
+ self.use_weighted_layer_sum = use_weighted_layer_sum
+ self.max_source_positions = max_source_positions
+
+ if position_embeddings_type is not None and position_embeddings_type not in [
+ "rotary",
+ "relative",
+ "relative_key",
+ ]:
+ raise ValueError(
+ """
+ `position_embeddings_type` is not valid. It must be one of the following values:
+ `["rotary", "relative", "relative_key"]` or left as `None`.
+ """
+ )
+ self.position_embeddings_type = position_embeddings_type
+ self.rotary_embedding_base = rotary_embedding_base
+ self.left_max_position_embeddings = left_max_position_embeddings
+ self.right_max_position_embeddings = right_max_position_embeddings
+
+ # Conformer-block related
+ self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
+ self.conformer_conv_dropout = conformer_conv_dropout
+
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
+ self.apply_spec_augment = apply_spec_augment
+ self.mask_time_prob = mask_time_prob
+ self.mask_time_length = mask_time_length
+ self.mask_time_min_masks = mask_time_min_masks
+ self.mask_feature_prob = mask_feature_prob
+ self.mask_feature_length = mask_feature_length
+ self.mask_feature_min_masks = mask_feature_min_masks
+
+ # ctc loss
+ self.ctc_loss_reduction = ctc_loss_reduction
+ self.ctc_zero_infinity = ctc_zero_infinity
+
+ # adapter
+ self.add_adapter = add_adapter
+ self.adapter_kernel_size = adapter_kernel_size
+ self.adapter_stride = adapter_stride
+ self.num_adapter_layers = num_adapter_layers
+ self.adapter_act = adapter_act
+ self.output_hidden_size = output_hidden_size if output_hidden_size is not None else hidden_size
+ if use_intermediate_ffn_before_adapter and not add_adapter:
+ raise ValueError("`use_intermediate_ffn_before_adapter` is `True` but `add_adapter` is `False`.")
+ self.use_intermediate_ffn_before_adapter = use_intermediate_ffn_before_adapter
+
+ # SequenceClassification-specific parameter. Feel free to ignore for other classes.
+ self.classifier_proj_size = classifier_proj_size
+
+ # XVector-specific parameters. Feel free to ignore for other classes.
+ self.tdnn_dim = list(tdnn_dim)
+ self.tdnn_kernel = list(tdnn_kernel)
+ self.tdnn_dilation = list(tdnn_dilation)
+ self.xvector_output_dim = xvector_output_dim
+
+ @property
+ def inputs_to_logits_ratio(self):
+ return functools.reduce(operator.mul, self.conv_stride, 1)
diff --git a/src/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py b/src/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py
new file mode 100644
index 00000000000000..8b77cd71f7f7e0
--- /dev/null
+++ b/src/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py
@@ -0,0 +1,218 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Wav2Vec2Bert BERT checkpoint."""
+
+
+import argparse
+
+import torch
+import torchaudio
+from fairseq2.data import Collater
+from fairseq2.data.audio import WaveformToFbankConverter
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+from seamless_communication.models.conformer_shaw import load_conformer_shaw_model
+
+from transformers import (
+ SeamlessM4TFeatureExtractor,
+ Wav2Vec2BertConfig,
+ Wav2Vec2BertModel,
+ logging,
+)
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+wav2vec_convert_list = [
+ ("encoder_frontend.model_dim_proj", "feature_projection.projection"),
+ ("encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"),
+ ("encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"),
+ ("encoder.inner.layers", "encoder.layers"),
+ ("encoder.inner_layer_norm", "encoder.layer_norm"),
+ ("encoder.adaptor_layers", "adapter.layers"),
+ ("inner_proj", "intermediate_dense"),
+ ("self_attn.output_proj", "self_attn.linear_out"),
+ ("output_proj", "output_dense"),
+ ("self_attn.k_proj", "self_attn.linear_k"),
+ ("self_attn.v_proj", "self_attn.linear_v"),
+ ("self_attn.q_proj", "self_attn.linear_q"),
+ ("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"),
+ ("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"),
+ ("self_attn.sdpa.rel_k_embed", "self_attn.distance_embedding"),
+ ("self_attn.sdpa.r_proj", "self_attn.linear_pos"),
+ ("conv.pointwise_conv1", "conv_module.pointwise_conv1"),
+ ("conv.pointwise_conv2", "conv_module.pointwise_conv2"),
+ ("conv.depthwise_conv", "conv_module.depthwise_conv"),
+ ("conv.layer_norm", "conv_module.depthwise_layer_norm"),
+ ("conv_layer_norm", "conv_module.layer_norm"),
+ ("encoder.proj1", "intermediate_ffn.intermediate_dense"),
+ ("encoder.proj2", "intermediate_ffn.output_dense"),
+ ("encoder.layer_norm", "inner_layer_norm"),
+ ("masker.temporal_mask_embed", "masked_spec_embed"),
+]
+
+keys_to_remove = {
+ "quantizer.entry_proj",
+ "final_proj",
+ "final_target_proj",
+ "quantizer.entries",
+ "quantizer.num_updates",
+}
+
+
+def param_count(model):
+ return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0])
+
+
+def _convert_model(
+ original_model,
+ hf_model,
+ convert_list,
+):
+ state_dict = original_model.state_dict()
+
+ for k, v in list(state_dict.items()):
+ new_key = k
+ for old_layer_name, new_layer_name in convert_list:
+ if old_layer_name in new_key:
+ new_key = new_key.replace(old_layer_name, new_layer_name)
+
+ # must do it by hand
+ if ".layer_norm" in new_key and new_key.split(".layer_norm")[0][-1].isnumeric():
+ new_key = new_key.replace("layer_norm", "final_layer_norm")
+
+ add_key = True
+ for key in keys_to_remove:
+ if key in new_key:
+ state_dict.pop(k)
+ add_key = False
+ break
+
+ if add_key:
+ state_dict[new_key] = state_dict.pop(k)
+
+ extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys())
+ extra_keys = set({k for k in extra_keys if "num_updates" not in k}) # filter unecessary param
+ missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys())
+ if len(extra_keys) != 0:
+ raise ValueError(f"extra keys found: {extra_keys}")
+ if len(missing_keys) != 0:
+ raise ValueError(f"missing keys: {missing_keys}")
+ hf_model.load_state_dict(state_dict, strict=True)
+ n_params = param_count(hf_model)
+
+ logger.info(f"model loaded: {round(n_params/1e6,1)}M params")
+
+ hf_model.eval()
+ del state_dict
+
+ return hf_model
+
+
+@torch.no_grad()
+def convert_wav2vec2_bert_checkpoint(
+ checkpoint_path,
+ pytorch_dump_folder_path,
+ config_path=None,
+ repo_id=None,
+):
+ """
+ Copy/paste/tweak model's weights to transformers design.
+ """
+ if config_path is not None:
+ config = Wav2Vec2BertConfig.from_pretrained(config_path, hidden_act="swish")
+ else:
+ config = Wav2Vec2BertConfig(apply_spec_augment=False)
+
+ hf_wav2vec = Wav2Vec2BertModel(config)
+
+ model = load_conformer_shaw_model(checkpoint_path, dtype=torch.float32)
+ model.eval()
+
+ hf_wav2vec = _convert_model(model, hf_wav2vec, wav2vec_convert_list)
+
+ hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
+
+ if repo_id:
+ hf_wav2vec.push_to_hub(repo_id, create_pr=True)
+
+ # save feature extractor
+ fe = SeamlessM4TFeatureExtractor(padding_value=1)
+ fe._set_processor_class("Wav2Vec2BertProcessor")
+ fe.save_pretrained(pytorch_dump_folder_path)
+
+ if repo_id:
+ fe.push_to_hub(repo_id, create_pr=True)
+
+ if args.audio_path:
+ waveform, sample_rate = torchaudio.load(args.audio_path)
+ waveform = torchaudio.functional.resample(waveform, sample_rate, fe.sampling_rate)
+
+ fbank_converter = WaveformToFbankConverter(
+ num_mel_bins=80,
+ waveform_scale=2**15,
+ channel_last=True,
+ standardize=True,
+ dtype=torch.float32,
+ )
+ collater = Collater(pad_value=1)
+
+ decoded_audio = {"waveform": waveform.T, "sample_rate": fe.sampling_rate, "format": -1}
+ src = collater(fbank_converter(decoded_audio))["fbank"]
+ seqs, padding_mask = get_seqs_and_padding_mask(src)
+
+ with torch.inference_mode():
+ seqs, padding_mask = model.encoder_frontend(seqs, padding_mask)
+ original_output, padding_mask = model.encoder(seqs, padding_mask)
+
+ hf_wav2vec.eval()
+
+ inputs = fe(waveform, return_tensors="pt", padding=True)
+ with torch.no_grad():
+ outputs = hf_wav2vec(**inputs)
+
+ torch.testing.assert_close(original_output, outputs.last_hidden_state, atol=5e-3, rtol=5e-3)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ default=None,
+ type=str,
+ help="Path to the output PyTorch model.",
+ )
+ parser.add_argument(
+ "--checkpoint_path", default="conformer_shaw", type=str, help="Path to seamless communication checkpoint"
+ )
+ parser.add_argument(
+ "--config_path",
+ default=None,
+ type=str,
+ help="Path to hf config.json of model to convert",
+ )
+ parser.add_argument("--repo_id", default=None, type=str, help="Push to this repo id if precised.")
+ parser.add_argument(
+ "--audio_path",
+ default=None,
+ type=str,
+ help="If specified, check that the original model and the converted model produce the same outputs.",
+ )
+
+ args = parser.parse_args()
+ convert_wav2vec2_bert_checkpoint(
+ args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.repo_id
+ )
diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
new file mode 100644
index 00000000000000..034da900ee8ab3
--- /dev/null
+++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
@@ -0,0 +1,1667 @@
+# coding=utf-8
+# Copyright 2024 The Seamless Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Wav2Vec2-BERT model."""
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+from .configuration_wav2vec2_bert import Wav2Vec2BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "Wav2Vec2BertConfig"
+
+# Base docstring
+_BASE_CHECKPOINT_FOR_DOC = "facebook/w2v-bert-2.0"
+_PRETRAINED_CHECKPOINT_FOR_DOC = "hf-audio/wav2vec2-bert-CV16-en"
+_EXPECTED_OUTPUT_SHAPE = [1, 146, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'mr quilter is the apostle of the middle classes and we are glad to welcome his gospel'"
+_CTC_EXPECTED_LOSS = 17.04
+
+
+WAV2VEC2_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/w2v-bert-2.0",
+ # See all Wav2Vec2-BERT models at https://huggingface.co/models?filter=wav2vec2-bert
+]
+
+
+# Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask
+def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor):
+ """
+ Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that
+ stops at the corresponding element in `seq_lens`.
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`):
+ The sequences to mask, where `*` is any number of sequence-specific dimensions including none.
+ seq_lens (`torch.Tensor` of shape `(batch)`:
+ Each element represents the length of the sequence at the same index in `hidden_states`
+ Returns:
+ `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)`
+ """
+ batch_size, mask_seq_len = hidden_states.shape[:2]
+
+ indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1)
+
+ bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len)
+
+ mask = hidden_states.new_ones((batch_size, mask_seq_len))
+
+ mask = mask.masked_fill(bool_mask, 0)
+
+ return mask
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.sum(-1).detach().tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller then
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
+def _sample_negative_indices(
+ features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
+):
+ """
+ Sample `num_negatives` vectors from feature vectors.
+ """
+ batch_size, sequence_length = features_shape
+
+ # generate indices of the positive vectors themselves, repeat them `num_negatives` times
+ sequence_length_range = np.arange(sequence_length)
+
+ # get `num_negatives` random vector indices from the same utterance
+ sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
+
+ mask_time_indices = (
+ mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
+ )
+
+ for batch_idx in range(batch_size):
+ high = mask_time_indices[batch_idx].sum() - 1
+ mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
+
+ feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
+ sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
+ # avoid sampling the same positive vector, but keep the distribution uniform
+ sampled_indices[sampled_indices >= feature_indices] += 1
+
+ # remap to actual indices
+ sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
+
+ # correct for batch size
+ sampled_negative_indices[batch_idx] += batch_idx * sequence_length
+
+ return sampled_negative_indices
+
+
+# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRotaryPositionalEmbedding with Wav2Vec2Conformer->Wav2Vec2Bert
+class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module):
+ """Rotary positional embedding
+ Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ dim = config.hidden_size // config.num_attention_heads
+ base = config.rotary_embedding_base
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ # Ignore copy
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.cached_sequence_length = None
+ self.cached_rotary_positional_embedding = None
+
+ def forward(self, hidden_states):
+ sequence_length = hidden_states.shape[1]
+
+ if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
+ return self.cached_rotary_positional_embedding
+
+ self.cached_sequence_length = sequence_length
+ # Embeddings are computed in the dtype of the inv_freq constant
+ time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
+ freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
+ embeddings = torch.cat((freqs, freqs), dim=-1)
+
+ cos_embeddings = embeddings.cos()[:, None, None, :]
+ sin_embeddings = embeddings.sin()[:, None, None, :]
+ # Computed embeddings are cast to the dtype of the hidden state inputs
+ self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
+ return self.cached_rotary_positional_embedding
+
+
+# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRelPositionalEmbedding with Wav2Vec2Conformer->Wav2Vec2Bert
+class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
+ """Relative positional encoding module."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.max_len = config.max_source_positions
+ self.d_model = config.hidden_size
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+
+ def extend_pe(self, x):
+ # Reset the positional encodings
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` is the position of query vector and `j` is the
+ # position of key vector. We use positive relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (i (batch, 2*channel, dim)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # => (batch, channel, dim)
+ hidden_states = self.glu(hidden_states)
+
+ # Pad the sequence entirely on the left because of causal convolution.
+ hidden_states = torch.nn.functional.pad(hidden_states, (self.depthwise_conv.kernel_size[0] - 1, 0))
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+
+ hidden_states = self.depthwise_layer_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Wav2Vec2BertSelfAttention(nn.Module):
+ """Construct an Wav2Vec2BertSelfAttention object.
+ Can be enhanced with rotary or relative position embeddings.
+ """
+
+ def __init__(self, config, is_adapter_attention=False):
+ super().__init__()
+ hidden_size = config.hidden_size if not is_adapter_attention else config.output_hidden_size
+
+ self.head_size = hidden_size // config.num_attention_heads
+ self.num_heads = config.num_attention_heads
+ self.position_embeddings_type = config.position_embeddings_type if not is_adapter_attention else None
+
+ self.linear_q = nn.Linear(hidden_size, hidden_size)
+ self.linear_k = nn.Linear(hidden_size, hidden_size)
+ self.linear_v = nn.Linear(hidden_size, hidden_size)
+ self.linear_out = nn.Linear(hidden_size, hidden_size)
+
+ self.dropout = nn.Dropout(p=config.attention_dropout)
+
+ if self.position_embeddings_type == "relative":
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(hidden_size, hidden_size, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+
+ if self.position_embeddings_type == "relative_key":
+ self.left_max_position_embeddings = config.left_max_position_embeddings
+ self.right_max_position_embeddings = config.right_max_position_embeddings
+ num_positions = self.left_max_position_embeddings + self.right_max_position_embeddings + 1
+ self.distance_embedding = nn.Embedding(num_positions, self.head_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # self-attention mechanism
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ # make sure query/key states can be != value states
+ query_key_states = hidden_states
+ value_states = hidden_states
+
+ if self.position_embeddings_type == "rotary":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
+ )
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
+
+ # project query_key_states and value_states
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
+
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ if self.position_embeddings_type == "relative":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
+ " 'relative'"
+ )
+ # apply relative_position_embeddings to qk scores
+ # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
+ scores = self._apply_relative_embeddings(
+ query=query, key=key, relative_position_embeddings=relative_position_embeddings
+ )
+ else:
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
+
+ if self.position_embeddings_type == "relative_key":
+ query_length, key_length = query.shape[2], key.shape[2]
+
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_r - position_ids_l
+ distance = torch.clamp(distance, -self.left_max_position_embeddings, self.right_max_position_embeddings)
+
+ positional_embedding = self.distance_embedding(distance + self.left_max_position_embeddings)
+ positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
+
+ relative_position_attn_weights = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
+ scores = scores + (relative_position_attn_weights / math.sqrt(self.head_size))
+
+ # apply attention_mask if necessary
+ if attention_mask is not None:
+ scores = scores + attention_mask
+
+ # => (batch, head, time1, time2)
+ probs = torch.softmax(scores, dim=-1)
+ probs = self.dropout(probs)
+
+ # => (batch, head, time1, d_k)
+ hidden_states = torch.matmul(probs, value)
+
+ # => (batch, time1, hidden_size)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
+ hidden_states = self.linear_out(hidden_states)
+
+ return hidden_states, probs
+
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_rotary_embedding
+ def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
+
+ cos = relative_position_embeddings[0, :sequence_length, ...]
+ sin = relative_position_embeddings[1, :sequence_length, ...]
+
+ # rotate hidden_states with rotary embeddings
+ hidden_states = hidden_states.transpose(0, 1)
+ rotated_states_begin = hidden_states[..., : self.head_size // 2]
+ rotated_states_end = hidden_states[..., self.head_size // 2 :]
+ rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
+ hidden_states = (hidden_states * cos) + (rotated_states * sin)
+ hidden_states = hidden_states.transpose(0, 1)
+
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
+
+ return hidden_states
+
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_relative_embeddings
+ def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
+ # 1. project positional embeddings
+ # => (batch, head, 2*time1-1, d_k)
+ proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.view(
+ relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
+ )
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
+
+ # 2. Add bias to query
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+ # 3. attention score: first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # => (batch, head, time1, time2)
+ scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+ # 4. then compute matrix b and matrix d
+ # => (batch, head, time1, 2*time1-1)
+ scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
+
+ # 5. shift matrix b and matrix d
+ zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
+ scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
+ scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
+ scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
+ scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
+ scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
+
+ # 6. sum matrices
+ # => (batch, head, time1, time2)
+ scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
+
+ return scores
+
+
+class Wav2Vec2BertEncoderLayer(nn.Module):
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
+
+ def __init__(self, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ dropout = config.attention_dropout
+
+ # Feed-forward 1
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.ffn1 = Wav2Vec2BertFeedForward(config)
+
+ # Self-Attention
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.self_attn_dropout = nn.Dropout(dropout)
+ self.self_attn = Wav2Vec2BertSelfAttention(config)
+
+ # Conformer Convolution
+ self.conv_module = Wav2Vec2BertConvolutionModule(config)
+
+ # Feed-forward 2
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.ffn2 = Wav2Vec2BertFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ conv_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ hidden_states = hidden_states
+
+ # 1. Feed-Forward 1 layer
+ residual = hidden_states
+ hidden_states = self.ffn1_layer_norm(hidden_states)
+ hidden_states = self.ffn1(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ residual = hidden_states
+
+ # 2. Self-Attention layer
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weigts = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.self_attn_dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ # 3. Convolutional Layer
+ residual = hidden_states
+ hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask)
+ hidden_states = residual + hidden_states
+
+ # 4. Feed-Forward 2 Layer
+ residual = hidden_states
+ hidden_states = self.ffn2_layer_norm(hidden_states)
+ hidden_states = self.ffn2(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ return hidden_states, attn_weigts
+
+
+class Wav2Vec2BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ if config.position_embeddings_type == "relative":
+ self.embed_positions = Wav2Vec2BertRelPositionalEmbedding(config)
+ elif config.position_embeddings_type == "rotary":
+ self.embed_positions = Wav2Vec2BertRotaryPositionalEmbedding(config)
+ else:
+ self.embed_positions = None
+
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList([Wav2Vec2BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ conv_attention_mask = attention_mask
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0)
+
+ # extend attention_mask
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+ )
+
+ hidden_states = self.dropout(hidden_states)
+
+ if self.embed_positions is not None:
+ relative_position_embeddings = self.embed_positions(hidden_states)
+ else:
+ relative_position_embeddings = None
+
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = torch.rand([])
+
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
+ # under deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer.__call__,
+ hidden_states,
+ attention_mask,
+ relative_position_embeddings,
+ output_attentions,
+ conv_attention_mask,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ conv_attention_mask=conv_attention_mask,
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class Wav2Vec2BertAdapter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ # feature dim might need to be down-projected
+ if config.output_hidden_size != config.hidden_size:
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size, eps=config.layer_norm_eps)
+ else:
+ self.proj = self.proj_layer_norm = None
+ self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))
+ self.layerdrop = config.layerdrop
+
+ self.kernel_size = config.adapter_kernel_size
+ self.stride = config.adapter_stride
+
+ def _compute_sub_sample_lengths_from_attention_mask(self, seq_lens):
+ if seq_lens is None:
+ return seq_lens
+ pad = self.kernel_size // 2
+ seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1
+ return seq_lens.floor()
+
+ def forward(self, hidden_states, attention_mask=None):
+ # down project hidden_states if necessary
+ if self.proj is not None and self.proj_layer_norm is not None:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.proj_layer_norm(hidden_states)
+
+ sub_sampled_lengths = None
+ if attention_mask is not None:
+ sub_sampled_lengths = (attention_mask.size(1) - (1 - attention_mask.int()).sum(1)).to(hidden_states.device)
+
+ for layer in self.layers:
+ layerdrop_prob = torch.rand([])
+ sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(sub_sampled_lengths)
+ if not self.training or (layerdrop_prob > self.layerdrop):
+ hidden_states = layer(
+ hidden_states, attention_mask=attention_mask, sub_sampled_lengths=sub_sampled_lengths
+ )
+
+ return hidden_states
+
+
+class Wav2Vec2BertAdapterLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ embed_dim = config.output_hidden_size
+ dropout = config.conformer_conv_dropout
+
+ self.kernel_size = config.adapter_kernel_size
+ self.stride = config.adapter_stride
+
+ # 1. residual convolution
+ self.residual_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.residual_conv = nn.Conv1d(
+ embed_dim,
+ 2 * embed_dim,
+ self.kernel_size,
+ stride=self.stride,
+ padding=self.stride // 2,
+ )
+ self.activation = nn.GLU(dim=1)
+
+ # Self-Attention
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.self_attn_conv = nn.Conv1d(
+ embed_dim,
+ 2 * embed_dim,
+ self.kernel_size,
+ stride=self.stride,
+ padding=self.stride // 2,
+ )
+ self.self_attn = Wav2Vec2BertSelfAttention(config, is_adapter_attention=True)
+ self.self_attn_dropout = nn.Dropout(dropout)
+
+ # Feed-forward
+ self.ffn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.ffn = Wav2Vec2BertFeedForward(config, act_fn=config.adapter_act, hidden_size=embed_dim)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ sub_sampled_lengths: Optional[torch.Tensor] = None,
+ ):
+ residual = self.residual_layer_norm(hidden_states)
+
+ # Apply pooling to the residual to match the sequence length of the
+ # multi-head attention output.
+ # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len)
+ residual = residual.transpose(1, 2)
+ residual = self.residual_conv(residual)
+ residual = self.activation(residual)
+ # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim)
+ residual = residual.transpose(1, 2)
+
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ # Apply pooling before feeding to the multihead-attention layer.
+ # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len)
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.self_attn_conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim)
+ hidden_states = hidden_states.transpose(1, 2)
+
+ if attention_mask is not None:
+ attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths)
+ attention_mask = _prepare_4d_attention_mask(
+ attention_mask,
+ hidden_states.dtype,
+ )
+
+ # The rest of the computation is identical to a vanilla Transformer
+ # encoder layer.
+ hidden_states, attn_weigths = self.self_attn(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.self_attn_dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ residual = hidden_states
+
+ hidden_states = self.ffn_layer_norm(hidden_states)
+ hidden_states = self.ffn(hidden_states) + residual
+
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerPreTrainedModel with Wav2Vec2Conformer->Wav2Vec2Bert,wav2vec2_conformer->wav2vec2_bert, input_values->input_features
+class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Wav2Vec2BertConfig
+ base_model_prefix = "wav2vec2_bert"
+ main_input_name = "input_features"
+ supports_gradient_checkpointing = True
+
+ # Ignore copy
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, Wav2Vec2BertSelfAttention):
+ if hasattr(module, "pos_bias_u"):
+ nn.init.xavier_uniform_(module.pos_bias_u)
+ if hasattr(module, "pos_bias_v"):
+ nn.init.xavier_uniform_(module.pos_bias_v)
+ elif isinstance(module, Wav2Vec2BertFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ # Ignore copy
+ def _get_feat_extract_output_lengths(
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+ ):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+ def _conv_out_length(input_length, kernel_size, stride, padding):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length + 2 * padding - kernel_size, stride, rounding_mode="floor") + 1
+
+ if add_adapter:
+ padding = self.config.adapter_kernel_size // 2
+ for _ in range(self.config.num_adapter_layers):
+ input_lengths = _conv_out_length(
+ input_lengths, self.config.adapter_kernel_size, self.config.adapter_stride, padding
+ )
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+ ):
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
+ # on inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+ output_lengths = output_lengths.to(torch.long)
+
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations makes sure that all values before the output lengths idxs are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+
+WAV2VEC2_BERT_START_DOCSTRING = r"""
+ Wav2Vec2Bert was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
+ Auli.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving etc.).
+
+ This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
+ regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
+
+ Parameters:
+ config ([`Wav2Vec2BertConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+WAV2VEC2_BERT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+ soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Wav2Vec2Bert Model transformer outputting raw hidden-states without any specific head on top.",
+ WAV2VEC2_BERT_START_DOCSTRING,
+)
+class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
+ def __init__(self, config: Wav2Vec2BertConfig):
+ super().__init__(config)
+ self.config = config
+ self.feature_projection = Wav2Vec2BertFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
+
+ self.encoder = Wav2Vec2BertEncoder(config)
+
+ self.adapter = Wav2Vec2BertAdapter(config) if config.add_adapter else None
+
+ self.intermediate_ffn = None
+ if config.use_intermediate_ffn_before_adapter:
+ self.intermediate_ffn = Wav2Vec2BertFeedForward(config, act_fn="relu")
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+ def _mask_hidden_states(
+ self,
+ hidden_states: torch.FloatTensor,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
+ """
+
+ # `config.apply_spec_augment` can set masking to False
+ if not getattr(self.config, "apply_spec_augment", True):
+ return hidden_states
+
+ # generate indices & apply SpecAugment along time axis
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ if mask_time_indices is not None:
+ # apply SpecAugment along time axis with given mask_time_indices
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+ elif self.config.mask_time_prob > 0 and self.training:
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ mask_prob=self.config.mask_time_prob,
+ mask_length=self.config.mask_time_length,
+ attention_mask=attention_mask,
+ min_masks=self.config.mask_time_min_masks,
+ )
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+ if self.config.mask_feature_prob > 0 and self.training:
+ # generate indices & apply SpecAugment along feature axis
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ mask_prob=self.config.mask_feature_prob,
+ mask_length=self.config.mask_feature_length,
+ min_masks=self.config.mask_feature_min_masks,
+ )
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+ hidden_states[mask_feature_indices] = 0
+
+ return hidden_states
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC,
+ output_type=Wav2Vec2BaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_features: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states, extract_features = self.feature_projection(input_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if self.intermediate_ffn:
+ expanded_hidden_states = self.intermediate_ffn(hidden_states)
+ hidden_states = hidden_states + 0.5 * expanded_hidden_states
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states, attention_mask=attention_mask)
+
+ if not return_dict:
+ return (hidden_states, extract_features) + encoder_outputs[1:]
+
+ return Wav2Vec2BaseModelOutput(
+ last_hidden_state=hidden_states,
+ extract_features=extract_features,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """Wav2Vec2Bert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ WAV2VEC2_BERT_START_DOCSTRING,
+)
+class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForCTC.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
+ def __init__(self, config, target_lang: Optional[str] = None):
+ super().__init__(config)
+
+ self.wav2vec2_bert = Wav2Vec2BertModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ self.target_lang = target_lang
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that "
+ "does not define the vocabulary size of the language model head. Please "
+ "instantiate the model as follows: `Wav2Vec2BertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+ "or define `vocab_size` of your model's configuration."
+ )
+ output_hidden_size = (
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+ )
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ def forward(
+ self,
+ input_features: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.wav2vec2_bert(
+ input_features,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ if labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask
+ if attention_mask is not None
+ else torch.ones(input_features.shape[:2], device=input_features.device, dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum([-1])).to(torch.long)
+
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ # ctc_loss doesn't support fp16
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Bert Model with a sequence classification head on top (a linear layer over the pooled output) for
+ tasks like SUPERB Keyword Spotting.
+ """,
+ WAV2VEC2_BERT_START_DOCSTRING,
+)
+class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Sequence classification does not support the use of Wav2Vec2Bert adapters (config.add_adapter=True)"
+ )
+ self.wav2vec2_bert = Wav2Vec2BertModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_bert.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_BASE_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert,WAV_2_VEC_2->WAV2VEC2_BERT, input_values->input_features
+ def forward(
+ self,
+ input_features: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_bert(
+ input_features,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+ if attention_mask is None:
+ pooled_output = hidden_states.mean(dim=1)
+ else:
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+ hidden_states[~padding_mask] = 0.0
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Bert Model with a frame classification head on top for tasks like Speaker Diarization.
+ """,
+ WAV2VEC2_BERT_START_DOCSTRING,
+)
+class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Audio frame classification does not support the use of Wav2Vec2Bert adapters (config.add_adapter=True)"
+ )
+ self.wav2vec2_bert = Wav2Vec2BertModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
+
+ self.init_weights()
+
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.freeze_base_model with wav2vec2_conformer->wav2vec2_bert
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_bert.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_BASE_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features
+ def forward(
+ self,
+ input_features: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_bert(
+ input_features,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
+class AMSoftmaxLoss(nn.Module):
+ def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
+ super(AMSoftmaxLoss, self).__init__()
+ self.scale = scale
+ self.margin = margin
+ self.num_labels = num_labels
+ self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
+ self.loss = nn.CrossEntropyLoss()
+
+ def forward(self, hidden_states, labels):
+ labels = labels.flatten()
+ weight = nn.functional.normalize(self.weight, dim=0)
+ hidden_states = nn.functional.normalize(hidden_states, dim=1)
+ cos_theta = torch.mm(hidden_states, weight)
+ psi = cos_theta - self.margin
+
+ onehot = nn.functional.one_hot(labels, self.num_labels)
+ logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
+ loss = self.loss(logits, labels)
+
+ return loss
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
+class TDNNLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
+ self.out_conv_dim = config.tdnn_dim[layer_id]
+ self.kernel_size = config.tdnn_kernel[layer_id]
+ self.dilation = config.tdnn_dilation[layer_id]
+
+ self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
+ self.activation = nn.ReLU()
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.unsqueeze(1)
+ hidden_states = nn.functional.unfold(
+ hidden_states,
+ (self.kernel_size, self.in_conv_dim),
+ stride=(1, self.in_conv_dim),
+ dilation=(self.dilation, 1),
+ )
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.kernel(hidden_states)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Bert Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """,
+ WAV2VEC2_BERT_START_DOCSTRING,
+)
+class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.wav2vec2_bert = Wav2Vec2BertModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
+
+ tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
+ self.tdnn = nn.ModuleList(tdnn_layers)
+
+ self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
+ self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
+
+ self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
+
+ self.init_weights()
+
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.freeze_base_model with wav2vec2_conformer->wav2vec2_bert
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_bert.parameters():
+ param.requires_grad = False
+
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector._get_tdnn_output_lengths
+ def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the TDNN layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return (input_length - kernel_size) // stride + 1
+
+ for kernel_size in self.config.tdnn_kernel:
+ input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
+
+ return input_lengths
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_BASE_CHECKPOINT_FOR_DOC,
+ output_type=XVectorOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features
+ def forward(
+ self,
+ input_features: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, XVectorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_bert(
+ input_features,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+
+ for tdnn_layer in self.tdnn:
+ hidden_states = tdnn_layer(hidden_states)
+
+ # Statistic Pooling
+ if attention_mask is None:
+ mean_features = hidden_states.mean(dim=1)
+ std_features = hidden_states.std(dim=1)
+ else:
+ feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
+ tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
+ mean_features = []
+ std_features = []
+ for i, length in enumerate(tdnn_output_lengths):
+ mean_features.append(hidden_states[i, :length].mean(dim=0))
+ std_features.append(hidden_states[i, :length].std(dim=0))
+ mean_features = torch.stack(mean_features)
+ std_features = torch.stack(std_features)
+ statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
+
+ output_embeddings = self.feature_extractor(statistic_pooling)
+ logits = self.classifier(output_embeddings)
+
+ loss = None
+ if labels is not None:
+ loss = self.objective(logits, labels)
+
+ if not return_dict:
+ output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return XVectorOutput(
+ loss=loss,
+ logits=logits,
+ embeddings=output_embeddings,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py
new file mode 100644
index 00000000000000..ec792ce75a0248
--- /dev/null
+++ b/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py
@@ -0,0 +1,145 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Speech processor class for Wav2Vec2-BERT
+"""
+import warnings
+
+from ...processing_utils import ProcessorMixin
+from ..seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor
+from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
+
+
+class Wav2Vec2BertProcessor(ProcessorMixin):
+ r"""
+ Constructs a Wav2Vec2-BERT processor which wraps a Wav2Vec2-BERT feature extractor and a Wav2Vec2 CTC tokenizer into a single
+ processor.
+
+ [`Wav2Vec2Processor`] offers all the functionalities of [`SeamlessM4TFeatureExtractor`] and [`PreTrainedTokenizer`].
+ See the docstring of [`~Wav2Vec2Processor.__call__`] and [`~Wav2Vec2Processor.decode`] for more information.
+
+ Args:
+ feature_extractor (`SeamlessM4TFeatureExtractor`):
+ An instance of [`SeamlessM4TFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer ([`PreTrainedTokenizer`]):
+ An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
+ """
+
+ feature_extractor_class = "SeamlessM4TFeatureExtractor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ try:
+ return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+ except OSError:
+ warnings.warn(
+ f"Loading a tokenizer inside {cls.__name__} from a config that does not"
+ " include a `tokenizer_class` attribute is deprecated and will be "
+ "removed in v5. Please add `'tokenizer_class': 'Wav2Vec2CTCTokenizer'`"
+ " attribute to either your `config.json` or `tokenizer_config.json` "
+ "file to suppress this warning: ",
+ FutureWarning,
+ )
+
+ feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+ return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ def __call__(self, audio=None, text=None, **kwargs):
+ """
+ Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `audio`
+ and `kwargs` arguments to SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audio` is not
+ `None` to pre-process the audio. To prepare the target sequences(s), this method forwards the `text` and `kwargs` arguments to
+ PreTrainedTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not `None`. Please refer to the doctsring of the above two methods for more information.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case
+ of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels,
+ and T the sample length of the audio.
+ kwargs (*optional*):
+ Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the
+ tokenizer.
+ Returns:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+ - **input_features** -- Audio input features to be fed to a model. Returned when `audio` is not `None`.
+ - **attention_mask** -- List of indices specifying which timestamps should be attended to by the model when `audio` is not `None`.
+ When only `text` is specified, returns the token attention mask.
+ - **labels** -- List of token ids to be fed to a model. Returned when both `text` and `audio` are not `None`.
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None` and `audio` is `None`.
+ """
+
+ sampling_rate = kwargs.pop("sampling_rate", None)
+
+ if audio is None and text is None:
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+ if audio is not None:
+ inputs = self.feature_extractor(audio, sampling_rate=sampling_rate, **kwargs)
+ if text is not None:
+ encodings = self.tokenizer(text, **kwargs)
+
+ if text is None:
+ return inputs
+ elif audio is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
+
+ def pad(self, input_features=None, labels=None, **kwargs):
+ """
+ If `input_features` is not `None`, this method forwards the `input_features` and `kwargs` arguments to SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.pad`] to pad the input features.
+ If `labels` is not `None`, this method forwards the `labels` and `kwargs` arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.pad`] to pad the label(s).
+ Please refer to the doctsring of the above two methods for more information.
+ """
+ if input_features is None and labels is None:
+ raise ValueError("You need to specify either an `input_features` or `labels` input to pad.")
+
+ if input_features is not None:
+ input_features = self.feature_extractor.pad(input_features, **kwargs)
+ if labels is not None:
+ labels = self.tokenizer.pad(labels, **kwargs)
+
+ if labels is None:
+ return input_features
+ elif input_features is None:
+ return labels
+ else:
+ input_features["labels"] = labels["input_ids"]
+ return input_features
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 4d89b2942f7997..345456b7908bff 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -8730,6 +8730,51 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+WAV2VEC2_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class Wav2Vec2BertForAudioFrameClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2BertForCTC(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2BertForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2BertForXVector(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2BertModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2BertPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/tests/models/wav2vec2_bert/__init__.py b/tests/models/wav2vec2_bert/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py
new file mode 100644
index 00000000000000..a4a0a95972c9f7
--- /dev/null
+++ b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py
@@ -0,0 +1,913 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Wav2Vec2-BERT model. """
+import tempfile
+import unittest
+
+from datasets import load_dataset
+
+from transformers import Wav2Vec2BertConfig, is_torch_available
+from transformers.testing_utils import (
+ is_pt_flax_cross_test,
+ require_torch,
+ require_torch_accelerator,
+ require_torch_fp16,
+ slow,
+ torch_device,
+)
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ AutoFeatureExtractor,
+ Wav2Vec2BertForAudioFrameClassification,
+ Wav2Vec2BertForCTC,
+ Wav2Vec2BertForSequenceClassification,
+ Wav2Vec2BertForXVector,
+ Wav2Vec2BertModel,
+ )
+ from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
+ _compute_mask_indices,
+ _sample_negative_indices,
+ )
+
+
+# Copied from tests.models.wav2vec2_conformer.test_modeling_wav2vec2_conformer.Wav2Vec2ConformerModelTester with Conformer->Bert, input_values->input_features
+class Wav2Vec2BertModelTester:
+ # Ignore copy
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=200, # speech is longer
+ is_training=False,
+ hidden_size=16,
+ feature_projection_input_dim=16,
+ num_conv_pos_embeddings=16,
+ num_conv_pos_embedding_groups=2,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ hidden_dropout_prob=0.1,
+ intermediate_size=20,
+ layer_norm_eps=1e-5,
+ hidden_act="gelu",
+ initializer_range=0.02,
+ mask_time_prob=0.5,
+ mask_time_length=2,
+ vocab_size=32,
+ do_stable_layer_norm=False,
+ num_adapter_layers=2,
+ adapter_stride=2,
+ tdnn_dim=(32, 32),
+ tdnn_kernel=(5, 3),
+ tdnn_dilation=(1, 2),
+ xvector_output_dim=32,
+ position_embeddings_type="relative",
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.hidden_size = hidden_size
+ self.feature_projection_input_dim = feature_projection_input_dim
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.intermediate_size = intermediate_size
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.vocab_size = vocab_size
+ self.do_stable_layer_norm = do_stable_layer_norm
+ self.num_adapter_layers = num_adapter_layers
+ self.adapter_stride = adapter_stride
+ self.mask_time_prob = mask_time_prob
+ self.mask_time_length = mask_time_length
+ self.scope = scope
+ self.tdnn_dim = tdnn_dim
+ self.tdnn_kernel = tdnn_kernel
+ self.tdnn_dilation = tdnn_dilation
+ self.xvector_output_dim = xvector_output_dim
+ self.position_embeddings_type = position_embeddings_type
+
+ self.output_seq_length = self.seq_length
+ self.encoder_seq_length = self.output_seq_length
+
+ self.adapter_output_seq_length = self.output_seq_length
+
+ for _ in range(num_adapter_layers):
+ self.adapter_output_seq_length = (self.adapter_output_seq_length - 1) // adapter_stride + 1
+
+ # Ignore copy
+ def prepare_config_and_inputs(self, position_embeddings_type="relative"):
+ input_shape = [self.batch_size, self.seq_length, self.feature_projection_input_dim]
+
+ input_features = floats_tensor(input_shape, self.vocab_size)
+ attention_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ config = self.get_config(position_embeddings_type=position_embeddings_type)
+
+ return config, input_features, attention_mask
+
+ # Ignore copy
+ def get_config(self, position_embeddings_type="relative"):
+ return Wav2Vec2BertConfig(
+ hidden_size=self.hidden_size,
+ feature_projection_input_dim=self.feature_projection_input_dim,
+ mask_time_prob=self.mask_time_prob,
+ mask_time_length=self.mask_time_length,
+ num_conv_pos_embeddings=self.num_conv_pos_embeddings,
+ num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ intermediate_size=self.intermediate_size,
+ layer_norm_eps=self.layer_norm_eps,
+ do_stable_layer_norm=self.do_stable_layer_norm,
+ hidden_act=self.hidden_act,
+ initializer_range=self.initializer_range,
+ vocab_size=self.vocab_size,
+ num_adapter_layers=self.num_adapter_layers,
+ adapter_stride=self.adapter_stride,
+ tdnn_dim=self.tdnn_dim,
+ tdnn_kernel=self.tdnn_kernel,
+ tdnn_dilation=self.tdnn_dilation,
+ xvector_output_dim=self.xvector_output_dim,
+ position_embeddings_type=position_embeddings_type,
+ )
+
+ def create_and_check_model(self, config, input_features, attention_mask):
+ model = Wav2Vec2BertModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_features, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_model_with_adapter(self, config, input_features, attention_mask):
+ config.add_adapter = True
+ model = Wav2Vec2BertModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_features, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.adapter_output_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_model_with_adapter_for_ctc(self, config, input_features, attention_mask):
+ config.add_adapter = True
+ config.output_hidden_size = 2 * config.hidden_size
+ model = Wav2Vec2BertForCTC(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_features, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.adapter_output_seq_length, self.vocab_size)
+ )
+
+ # Ignore copy
+ def create_and_check_model_with_intermediate_ffn_before_adapter(self, config, input_features, attention_mask):
+ config.add_adapter = True
+ config.use_intermediate_ffn_before_adapter = True
+ model = Wav2Vec2BertModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_features, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
+ )
+
+ # also try with different adapter proj dim
+ config.output_hidden_size = 8
+ model = Wav2Vec2BertModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_features, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
+ )
+
+ def create_and_check_model_with_adapter_proj_dim(self, config, input_features, attention_mask):
+ config.add_adapter = True
+ config.output_hidden_size = 8
+ model = Wav2Vec2BertModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_features, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
+ )
+
+ def create_and_check_model_float16(self, config, input_features, attention_mask):
+ model = Wav2Vec2BertModel(config=config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = Wav2Vec2BertModel.from_pretrained(tmpdirname, torch_dtype=torch.float16)
+
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ result = model(input_features.type(dtype=torch.float16), attention_mask=attention_mask)
+
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_batch_inference(self, config, input_features, *args):
+ # test does not pass for models making use of `group_norm`
+ # check: https://github.com/pytorch/fairseq/issues/3227
+ model = Wav2Vec2BertModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ input_features = input_features[:3]
+ attention_mask = torch.ones(input_features.shape, device=torch_device, dtype=torch.bool)
+
+ input_lengths = [input_features.shape[-1] // i for i in [4, 2, 1]]
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0.0
+
+ batch_outputs = model(input_features, attention_mask=attention_mask).last_hidden_state
+
+ for i in range(input_features.shape[0]):
+ input_slice = input_features[i : i + 1, : input_lengths[i]]
+ output = model(input_slice).last_hidden_state
+
+ batch_output = batch_outputs[i : i + 1, : output.shape[1]]
+ self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
+
+ def check_ctc_loss(self, config, input_features, *args):
+ model = Wav2Vec2BertForCTC(config=config)
+ model.to(torch_device)
+
+ # make sure that dropout is disabled
+ model.eval()
+
+ input_features = input_features[:3]
+ # Ignore copy
+ attention_mask = torch.ones(input_features.shape[:2], device=torch_device, dtype=torch.long)
+
+ input_lengths = [input_features.shape[1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_features.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0
+
+ model.config.ctc_loss_reduction = "sum"
+ sum_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()
+
+ model.config.ctc_loss_reduction = "mean"
+ mean_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()
+
+ self.parent.assertTrue(isinstance(sum_loss, float))
+ self.parent.assertTrue(isinstance(mean_loss, float))
+
+ def check_seq_classifier_loss(self, config, input_features, *args):
+ model = Wav2Vec2BertForSequenceClassification(config=config)
+ model.to(torch_device)
+
+ # make sure that dropout is disabled
+ model.eval()
+
+ input_features = input_features[:3]
+ # Ignore copy
+ attention_mask = torch.ones(input_features.shape[:2], device=torch_device, dtype=torch.long)
+
+ input_lengths = [input_features.shape[1] // i for i in [4, 2, 1]]
+ labels = ids_tensor((input_features.shape[0], 1), len(model.config.id2label))
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0
+
+ masked_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()
+ unmasked_loss = model(input_features, labels=labels).loss.item()
+
+ self.parent.assertTrue(isinstance(masked_loss, float))
+ self.parent.assertTrue(isinstance(unmasked_loss, float))
+ self.parent.assertTrue(masked_loss != unmasked_loss)
+
+ def check_ctc_training(self, config, input_features, *args):
+ config.ctc_zero_infinity = True
+ model = Wav2Vec2BertForCTC(config=config)
+ model.to(torch_device)
+ model.train()
+
+ # Ignore copy
+ input_features = input_features[:3]
+
+ input_lengths = [input_features.shape[1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_features.shape[0], max(max_length_labels) - 2), model.config.vocab_size)
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+
+ if max_length_labels[i] < labels.shape[-1]:
+ # it's important that we make sure that target lengths are at least
+ # one shorter than logit lengths to prevent -inf
+ labels[i, max_length_labels[i] - 1 :] = -100
+
+ loss = model(input_features, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_seq_classifier_training(self, config, input_features, *args):
+ config.ctc_zero_infinity = True
+ model = Wav2Vec2BertForSequenceClassification(config=config)
+ model.to(torch_device)
+ model.train()
+
+ # freeze everything but the classification head
+ model.freeze_base_model()
+
+ input_features = input_features[:3]
+
+ # Ignore copy
+ input_lengths = [input_features.shape[1] // i for i in [4, 2, 1]]
+ labels = ids_tensor((input_features.shape[0], 1), len(model.config.id2label))
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+
+ loss = model(input_features, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_xvector_training(self, config, input_features, *args):
+ config.ctc_zero_infinity = True
+ model = Wav2Vec2BertForXVector(config=config)
+ model.to(torch_device)
+ model.train()
+
+ # freeze everything but the classification head
+ model.freeze_base_model()
+
+ input_features = input_features[:3]
+
+ input_lengths = [input_features.shape[-1] // i for i in [4, 2, 1]]
+ labels = ids_tensor((input_features.shape[0], 1), len(model.config.id2label))
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+
+ loss = model(input_features, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_labels_out_of_vocab(self, config, input_features, *args):
+ model = Wav2Vec2BertForCTC(config)
+ model.to(torch_device)
+ model.train()
+
+ input_features = input_features[:3]
+
+ input_lengths = [input_features.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_features.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
+
+ with self.parent.assertRaises(ValueError):
+ model(input_features, labels=labels)
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_features, attention_mask = self.prepare_config_and_inputs()
+ inputs_dict = {"input_features": input_features, "attention_mask": attention_mask}
+ return config, inputs_dict
+
+
+@require_torch
+# Copied from tests.models.wav2vec2_conformer.test_modeling_wav2vec2_conformer.Wav2Vec2ConformerModelTest with Conformer->Bert, input_values->input_features
+class Wav2Vec2BertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ # Ignore copy
+ all_model_classes = (
+ (
+ Wav2Vec2BertForCTC,
+ Wav2Vec2BertModel,
+ Wav2Vec2BertForSequenceClassification,
+ Wav2Vec2BertForAudioFrameClassification,
+ Wav2Vec2BertForXVector,
+ )
+ if is_torch_available()
+ else ()
+ )
+
+ pipeline_model_mapping = (
+ {
+ "audio-classification": Wav2Vec2BertForSequenceClassification,
+ "automatic-speech-recognition": Wav2Vec2BertForCTC,
+ "feature-extraction": Wav2Vec2BertModel,
+ }
+ if is_torch_available()
+ else {}
+ )
+
+ test_pruning = False
+ test_headmasking = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = Wav2Vec2BertModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Wav2Vec2BertConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_relative(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ # Ignore copy
+ def test_model_with_relative_key(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative_key")
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_rotary(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_no_rel_pos(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type=None)
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_adapter(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_with_adapter(*config_and_inputs)
+
+ def test_model_with_adapter_for_ctc(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_with_adapter_for_ctc(*config_and_inputs)
+
+ # Ignore copy
+ def test_model_with_intermediate_ffn_before_adapter(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_with_intermediate_ffn_before_adapter(*config_and_inputs)
+
+ def test_model_with_adapter_proj_dim(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
+
+ @require_torch_accelerator
+ @require_torch_fp16
+ def test_model_float16_with_relative(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
+ self.model_tester.create_and_check_model_float16(*config_and_inputs)
+
+ # Ignore copy
+ @require_torch_accelerator
+ @require_torch_fp16
+ def test_model_float16_with_relative_key(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative_key")
+ self.model_tester.create_and_check_model_float16(*config_and_inputs)
+
+ @require_torch_accelerator
+ @require_torch_fp16
+ def test_model_float16_with_rotary(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
+ self.model_tester.create_and_check_model_float16(*config_and_inputs)
+
+ def test_ctc_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_loss(*config_and_inputs)
+
+ def test_seq_classifier_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_seq_classifier_loss(*config_and_inputs)
+
+ def test_ctc_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_training(*config_and_inputs)
+
+ def test_seq_classifier_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_seq_classifier_training(*config_and_inputs)
+
+ def test_xvector_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_xvector_training(*config_and_inputs)
+
+ def test_labels_out_of_vocab(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
+
+ # Ignore copy
+ @unittest.skip(reason="Wav2Vec2Bert has no inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ # Ignore copy
+ @unittest.skip(reason="`input_ids` is renamed to `input_features`")
+ def test_forward_signature(self):
+ pass
+
+ # Ignore copy
+ @unittest.skip(reason="Wav2Vec2Bert has no tokens embeddings")
+ def test_resize_tokens_embeddings(self):
+ pass
+
+ # Ignore copy
+ @unittest.skip(reason="Wav2Vec2Bert has no inputs_embeds")
+ def test_model_common_attributes(self):
+ pass
+
+ # Ignore copy
+ @unittest.skip(reason="non-robust architecture does not exist in Flax")
+ @is_pt_flax_cross_test
+ def test_equivalence_flax_to_pt(self):
+ pass
+
+ # Ignore copy
+ @unittest.skip(reason="non-robust architecture does not exist in Flax")
+ @is_pt_flax_cross_test
+ def test_equivalence_pt_to_flax(self):
+ pass
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ # set layer drop to 0
+ model.config.layerdrop = 0.0
+
+ input_features = inputs_dict["input_features"]
+
+ input_lengths = torch.tensor(
+ [input_features.shape[1] for _ in range(input_features.shape[0])], dtype=torch.long, device=torch_device
+ )
+ output_lengths = model._get_feat_extract_output_lengths(input_lengths)
+
+ labels = ids_tensor((input_features.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
+ inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
+ inputs_dict["labels"] = labels
+
+ outputs = model(**inputs_dict)
+
+ output = outputs[0]
+
+ # Encoder-/Decoder-only models
+ hidden_states = outputs.hidden_states[0]
+ attentions = outputs.attentions[0]
+
+ hidden_states.retain_grad()
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+ self.assertIsNotNone(attentions.grad)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ uniform_init_parms = [
+ "conv.weight",
+ "conv.parametrizations.weight",
+ "masked_spec_embed",
+ "codevectors",
+ "quantizer.weight_proj.weight",
+ "project_hid.weight",
+ "project_hid.bias",
+ "project_q.weight",
+ "project_q.bias",
+ "pos_bias_v",
+ "pos_bias_u",
+ "pointwise_conv1",
+ "pointwise_conv2",
+ "feature_projection.projection.weight",
+ "feature_projection.projection.bias",
+ "objective.weight",
+ ]
+ if param.requires_grad:
+ if any(x in name for x in uniform_init_parms):
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ # overwrite from test_modeling_common
+ def _mock_init_weights(self, module):
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.data.fill_(3)
+ if hasattr(module, "weight_g") and module.weight_g is not None:
+ module.weight_g.data.fill_(3)
+ if hasattr(module, "weight_v") and module.weight_v is not None:
+ module.weight_v.data.fill_(3)
+ if hasattr(module, "bias") and module.bias is not None:
+ module.bias.data.fill_(3)
+ if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None:
+ module.pos_bias_u.data.fill_(3)
+ if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None:
+ module.pos_bias_v.data.fill_(3)
+ if hasattr(module, "codevectors") and module.codevectors is not None:
+ module.codevectors.data.fill_(3)
+ if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
+ module.masked_spec_embed.data.fill_(3)
+
+ # Ignore copy
+ @unittest.skip(reason="Kept to make #Copied from working")
+ def test_mask_feature_prob_ctc(self):
+ pass
+
+ # Ignore copy
+ @unittest.skip(reason="Kept to make #Copied from working")
+ def test_mask_time_prob_ctc(self):
+ pass
+
+ @unittest.skip(reason="Feed forward chunking is not implemented")
+ def test_feed_forward_chunking(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ # Ignore copy
+ model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
+ self.assertIsNotNone(model)
+
+
+@require_torch
+# Copied from tests.models.wav2vec2_conformer.test_modeling_wav2vec2_conformer.Wav2Vec2ConformerUtilsTest with Conformer->Bert, input_values->input_features
+class Wav2Vec2BertUtilsTest(unittest.TestCase):
+ def test_compute_mask_indices(self):
+ batch_size = 4
+ sequence_length = 60
+ mask_prob = 0.5
+ mask_length = 1
+
+ mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
+
+ def test_compute_mask_indices_low_prob(self):
+ # with these settings num_masked_spans=0.5, which means probabilistic rounding
+ # ensures that in 5 out of 10 method calls, num_masked_spans=0, and in
+ # the other 5 out of 10, cases num_masked_spans=1
+ n_trials = 100
+ batch_size = 4
+ sequence_length = 100
+ mask_prob = 0.05
+ mask_length = 10
+
+ count_dimensions_masked = 0
+ count_dimensions_not_masked = 0
+
+ for _ in range(n_trials):
+ mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ num_masks = torch.sum(mask).item()
+
+ if num_masks > 0:
+ count_dimensions_masked += 1
+ else:
+ count_dimensions_not_masked += 1
+
+ # as we test for at least 10 masked dimension and at least
+ # 10 non-masked dimension, this test could fail with probability:
+ # P(100 coin flips, at most 9 heads) = 1.66e-18
+ self.assertGreater(count_dimensions_masked, int(n_trials * 0.1))
+ self.assertGreater(count_dimensions_not_masked, int(n_trials * 0.1))
+
+ def test_compute_mask_indices_overlap(self):
+ batch_size = 4
+ sequence_length = 80
+ mask_prob = 0.5
+ mask_length = 4
+
+ mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ # because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
+ for batch_sum in mask.sum(axis=-1):
+ self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
+ def test_compute_mask_indices_attn_mask_overlap(self):
+ batch_size = 4
+ sequence_length = 80
+ mask_prob = 0.5
+ mask_length = 4
+
+ attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+ attention_mask[:2, sequence_length // 2 :] = 0
+
+ mask = _compute_mask_indices(
+ (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
+ )
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ for batch_sum in mask.sum(axis=-1):
+ self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
+ self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
+
+ def test_compute_mask_indices_short_audio(self):
+ batch_size = 4
+ sequence_length = 100
+ mask_prob = 0.05
+ mask_length = 10
+
+ attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+ # force one example to be heavily padded
+ attention_mask[0, 5:] = 0
+
+ mask = _compute_mask_indices(
+ (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask, min_masks=2
+ )
+
+ # make sure that non-padded examples cannot be padded
+ self.assertFalse(mask[0][attention_mask[0].to(torch.bool).cpu()].any())
+
+ # Ignore copy
+ @unittest.skip(reason="Kept to make #Copied from working. Test a class used for pretraining, not yet supported.")
+ def test_compute_perplexity(self):
+ pass
+
+ def test_sample_negatives(self):
+ batch_size = 2
+ sequence_length = 10
+ hidden_size = 4
+ num_negatives = 3
+
+ features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
+ sequence_length, hidden_size
+ ) # each value in vector consits of same value
+ features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
+
+ # sample negative indices
+ sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
+ sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+ negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
+ negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
+ self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
+
+ # make sure no negatively sampled vector is actually a positive one
+ for negative in negatives:
+ self.assertTrue(((negative - features) == 0).sum() == 0.0)
+
+ # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
+ self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+
+ def test_sample_negatives_with_mask(self):
+ batch_size = 2
+ sequence_length = 10
+ hidden_size = 4
+ num_negatives = 3
+
+ # second half of last input tensor is padded
+ mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+ mask[-1, sequence_length // 2 :] = 0
+
+ features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
+ sequence_length, hidden_size
+ ) # each value in vector consits of same value
+ features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
+
+ # replace masked feature vectors with -100 to test that those are not sampled
+ features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)
+
+ # sample negative indices
+ sampled_negative_indices = _sample_negative_indices(
+ (batch_size, sequence_length), num_negatives, mask.cpu().numpy()
+ )
+ sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+ negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
+ negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
+
+ self.assertTrue((negatives >= 0).all().item())
+
+ self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
+
+ # make sure no negatively sampled vector is actually a positive one
+ for negative in negatives:
+ self.assertTrue(((negative - features) == 0).sum() == 0.0)
+
+ # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
+ self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+
+
+@require_torch
+@slow
+class Wav2Vec2BertModelIntegrationTest(unittest.TestCase):
+ def _load_datasamples(self, num_samples):
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ # automatic decoding with librispeech
+ speech_samples = ds.sort("id").filter(lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)])
+ speech_samples = speech_samples[:num_samples]["audio"]
+
+ return [x["array"] for x in speech_samples]
+
+ def test_inference_w2v2_bert(self):
+ model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
+ model.to(torch_device)
+ feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
+
+ input_speech = self._load_datasamples(2)
+
+ inputs = feature_extractor(input_speech, return_tensors="pt", padding=True).to(torch_device)
+
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**inputs, output_attentions=True)
+
+ # fmt: off
+ expected_slice_0 = torch.tensor(
+ [[-0.0098, -0.0570, -0.1286, 0.0439, -0.1037, -0.0235],
+ [-0.0767, 0.0574, -0.3224, 0.0482, 0.0440, -0.0193],
+ [ 0.0220, -0.0878, -0.2027, -0.0028, -0.0666, 0.0721],
+ [ 0.0307, -0.1099, 0.0273, -0.0416, -0.0715, 0.0094],
+ [ 0.0758, -0.0291, 0.1084, 0.0004, -0.0751, -0.0116],
+ [ 0.0349, -0.0343, -0.0098, 0.0415, -0.0617, 0.0241],
+ [-0.0193, -0.0171, 0.1965, 0.0797, -0.0308, 0.2033],
+ [-0.0323, -0.0315, 0.0948, 0.0944, -0.0254, 0.1241],
+ [-0.0493, 0.0010, -0.1762, 0.0034, -0.0787, 0.0832],
+ [ 0.0043, -0.1228, -0.0739, 0.0266, -0.0337, -0.0068]]
+ ).to(torch_device)
+ # fmt: on
+
+ # fmt: off
+ expected_slice_1 = torch.tensor(
+ [[-0.0348, -0.0521, -0.3036, 0.0285, -0.0715, -0.0453],
+ [-0.0102, 0.0114, -0.3266, 0.0027, -0.0558, 0.0038],
+ [ 0.0454, 0.0148, -0.2418, -0.0392, -0.0455, 0.0478],
+ [-0.0013, 0.0825, -0.1730, -0.0091, -0.0426, 0.0360],
+ [-0.0227, 0.0687, -0.1168, 0.0569, -0.0160, 0.0759],
+ [-0.0318, 0.0562, -0.0508, 0.0605, 0.0150, 0.0953],
+ [-0.0415, 0.0438, 0.0233, 0.0336, 0.0262, 0.0860],
+ [-0.0163, 0.0048, 0.0807, 0.0119, 0.0712, 0.0158],
+ [ 0.0244, -0.0145, 0.0262, -0.0237, 0.0283, -0.0125],
+ [-0.0587, -0.0516, -0.0368, -0.0196, 0.0307, -0.1434]]
+ ).to(torch_device)
+ # fmt: on
+
+ self.assertTrue((outputs.last_hidden_state[0, 25:35, 4:10] - expected_slice_0).abs().max() <= 1e-4)
+ self.assertTrue((outputs.last_hidden_state[1, 25:35, 4:10] - expected_slice_1).abs().max() <= 1e-4)
+
+ self.assertAlmostEqual(outputs.last_hidden_state[1].mean().item(), 3.3123e-05)
+ self.assertAlmostEqual(outputs.last_hidden_state[1].std().item(), 0.1545, delta=2e-5)
+
+ self.assertListEqual(list(outputs.last_hidden_state.shape), [2, 326, 1024])
diff --git a/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py b/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py
new file mode 100644
index 00000000000000..b6b1506f5e4d68
--- /dev/null
+++ b/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py
@@ -0,0 +1,156 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+from transformers.models.seamless_m4t import SeamlessM4TFeatureExtractor
+from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer
+from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
+from transformers.models.wav2vec2_bert import Wav2Vec2BertProcessor
+from transformers.utils import FEATURE_EXTRACTOR_NAME
+
+from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
+
+
+# Copied from tests.models.wav2vec2.test_processor_wav2vec2.Wav2Vec2ProcessorTest with Wav2Vec2FeatureExtractor->SeamlessM4TFeatureExtractor, Wav2Vec2Processor->Wav2Vec2BertProcessor
+class Wav2Vec2BertProcessorTest(unittest.TestCase):
+ def setUp(self):
+ vocab = " | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
+ vocab_tokens = dict(zip(vocab, range(len(vocab))))
+
+ self.add_kwargs_tokens_map = {
+ "pad_token": "",
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ }
+ feature_extractor_map = {
+ "feature_size": 1,
+ "padding_value": 0.0,
+ "sampling_rate": 16000,
+ "return_attention_mask": False,
+ "do_normalize": True,
+ }
+
+ self.tmpdirname = tempfile.mkdtemp()
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(vocab_tokens) + "\n")
+
+ with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(feature_extractor_map) + "\n")
+
+ def get_tokenizer(self, **kwargs_init):
+ kwargs = self.add_kwargs_tokens_map.copy()
+ kwargs.update(kwargs_init)
+ return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_feature_extractor(self, **kwargs):
+ return SeamlessM4TFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def test_save_load_pretrained_default(self):
+ tokenizer = self.get_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+
+ processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ processor.save_pretrained(self.tmpdirname)
+ processor = Wav2Vec2BertProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
+ self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, SeamlessM4TFeatureExtractor)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = Wav2Vec2BertProcessor(
+ tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor()
+ )
+ processor.save_pretrained(self.tmpdirname)
+
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
+
+ processor = Wav2Vec2BertProcessor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, SeamlessM4TFeatureExtractor)
+
+ def test_feature_extractor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ raw_speech = floats_list((3, 1000))
+
+ input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
+ input_processor = processor(raw_speech, return_tensors="np")
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ def test_tokenizer(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "This is a test string"
+
+ encoded_processor = processor(text=input_str)
+
+ encoded_tok = tokenizer(input_str)
+
+ for key in encoded_tok.keys():
+ self.assertListEqual(encoded_tok[key], encoded_processor[key])
+
+ def test_tokenizer_decode(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
+
+ decoded_processor = processor.batch_decode(predicted_ids)
+ decoded_tok = tokenizer.batch_decode(predicted_ids)
+
+ self.assertListEqual(decoded_tok, decoded_processor)
+
+ def test_model_input_names(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ self.assertListEqual(
+ processor.model_input_names,
+ feature_extractor.model_input_names,
+ msg="`processor` and `feature_extractor` model input names do not match",
+ )
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index 3c466310397946..f63ca3aba92c6e 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -762,6 +762,7 @@
"VitMatteForImageMatting",
"VitsTokenizer",
"VivitModel",
+ "Wav2Vec2BertForCTC",
"Wav2Vec2CTCTokenizer",
"Wav2Vec2Config",
"Wav2Vec2ConformerConfig",
diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt
index 611c515b82ca62..cc754971f08661 100644
--- a/utils/not_doctested.txt
+++ b/utils/not_doctested.txt
@@ -873,6 +873,7 @@ src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to
src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
+src/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py
src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint_to_pytorch.py