# transcribe.py
import itertools
import json
import logging
import os
import random
import zlib
from collections import Counter, defaultdict
from dataclasses import asdict, dataclass
from inspect import signature
from math import ceil
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
from warnings import warn

import ctranslate2
import numpy as np
import tokenizers
from tqdm import tqdm

from faster_whisper.audio import decode_audio, pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
from faster_whisper.vad import (
    SpeechTimestampsMap,
    VadOptions,
    collect_chunks,
    get_speech_timestamps,
    merge_segments,
)


@dataclass
class Word:
    start: float
    end: float
    word: str
    probability: float

    def _asdict(self):
        warn(
            "Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead",
            DeprecationWarning,
            2,
        )
        return asdict(self)


@dataclass
class Segment:
    id: int
    seek: int
    start: float
    end: float
    text: str
    tokens: List[int]
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float
    words: Optional[List[Word]]
    temperature: Optional[float] = 1.0

    def _asdict(self):
        warn(
            "Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead",
            DeprecationWarning,
            2,
        )
        return asdict(self)
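
# Note: dataclasses.asdict() recurses into nested dataclasses, so asdict(segment)
# also converts each Word in segment.words into a plain dict, which makes the
# result directly serializable with json.dumps().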


# Added additional parameters for multilingual videos and fixes below
@dataclass
class TranscriptionOptions:
    beam_size: int
    best_of: int
    patience: float
    length_penalty: float
    repetition_penalty: float
    no_repeat_ngram_size: int
    log_prob_threshold: Optional[float]
    log_prob_low_threshold: Optional[float]
    no_speech_threshold: Optional[float]
    compression_ratio_threshold: Optional[float]
    condition_on_previous_text: bool
    prompt_reset_on_temperature: float
    temperatures: List[float]
    initial_prompt: Optional[Union[str, Iterable[int]]]
    prefix: Optional[str]
    suppress_blank: bool
    suppress_tokens: Optional[List[int]]
    without_timestamps: bool
    max_initial_timestamp: float
    word_timestamps: bool
    prepend_punctuations: str
    append_punctuations: str
    multilingual: bool
    output_language: Optional[str]
    max_new_tokens: Optional[int]
    clip_timestamps: Union[str, List[float]]
    hallucination_silence_threshold: Optional[float]
    hotwords: Optional[str]


@dataclass
class TranscriptionInfo:
    language: str
    language_probability: float
    duration: float
    duration_after_vad: float
    all_language_probs: Optional[List[Tuple[str, float]]]
    transcription_options: TranscriptionOptions
    vad_options: VadOptions


# The code below is originally from HF pipeline and is used in whisper-x
# (https://github.com/m-bain/whisperX) and adapted for faster_whisper
class BatchedInferencePipeline:
    """
    Huggingface Pipeline wrapper for WhisperModel.
    Copyright (c) 2022, Max Bain
    All rights reserved.
    Modified by Mobius Labs GmbH
    """

    def __init__(
        self,
        model,
        options: Optional[TranscriptionOptions] = None,
        tokenizer=None,
        language: Optional[str] = None,
    ):
        self.model: WhisperModel = model
        self.tokenizer = tokenizer
        self.options = options
        self.preset_language = language
        self.last_speech_timestamp = 0.0

    def forward(self, features, chunks_metadata, **forward_params):
        encoder_output, outputs = self.model.generate_segment_batched(
            features, self.tokenizer, forward_params
        )

        segmented_outputs = []
        segment_sizes = []
        for chunk_metadata, output in zip(chunks_metadata, outputs):
            duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
            segment_size = int(ceil(duration) * self.model.frames_per_second)
            segment_sizes.append(segment_size)
            (
                subsegments,
                seek,
                single_timestamp_ending,
            ) = self.model._split_segments_by_timestamps(
                tokenizer=self.tokenizer,
                tokens=output["tokens"],
                time_offset=chunk_metadata["start_time"],
                segment_size=segment_size,
                segment_duration=duration,
                seek=0,
            )
            segmented_outputs.append(
                [
                    dict(
                        text=self.tokenizer.decode(subsegment["tokens"]),
                        avg_logprob=output["avg_logprob"],
                        no_speech_prob=output["no_speech_prob"],
                        tokens=subsegment["tokens"],
                        start=subsegment["start"],
                        end=subsegment["end"],
                        compression_ratio=get_compression_ratio(
                            self.tokenizer.decode(subsegment["tokens"])
                        ),
                        seek=int(
                            chunk_metadata["start_time"] * self.model.frames_per_second
                        ),
                    )
                    for subsegment in subsegments
                ]
            )
        if forward_params["word_timestamps"]:
            self.last_speech_timestamp = self.model.add_word_timestamps(
                segmented_outputs,
                self.tokenizer,
                encoder_output,
                segment_sizes,
                forward_params["prepend_punctuations"],
                forward_params["append_punctuations"],
                self.last_speech_timestamp,
            )

        return segmented_outputs

    def get_language_and_tokenizer(
        self, audio, task: Optional[str] = None, language: Optional[str] = None
    ):
        all_language_probs = None
        language_probability = 1.0

        if self.tokenizer is None:
            if not language:
                (
                    language,
                    language_probability,
                    all_language_probs,
                ) = self.model.detect_language(audio)
            task = task or "transcribe"
            self.tokenizer = Tokenizer(
                self.model.hf_tokenizer,
                self.model.model.is_multilingual,
                task=task,
                language=language,
            )
        else:
            if task is not None:
                self.tokenizer.task = self.tokenizer.tokenizer.token_to_id(
                    f"<|{task}|>"
                )
            if language is not None:
                self.tokenizer.language = self.tokenizer.tokenizer.token_to_id(
                    f"<|{language}|>"
                )
                self.tokenizer.language_code = language

        return language, language_probability, task, all_language_probs
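
    # Note: the tokenizer built above is cached on the pipeline instance, so a
    # detected language carries over to later calls; _batched_segments_generator
    # resets it after each run when no preset language was given.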

    def transcribe(
        self,
        audio: Union[str, BinaryIO, np.ndarray],
        language: Optional[str] = None,
        task: Optional[str] = None,
        log_progress: bool = False,
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1,
        length_penalty: float = 1,
        repetition_penalty: float = 1,
        no_repeat_ngram_size: int = 0,
        temperature: Union[float, List[float], Tuple[float, ...]] = [
            0.0,
            0.2,
            0.4,
            0.6,
            0.8,
            1.0,
        ],
        compression_ratio_threshold: Optional[float] = 2.4,
        log_prob_threshold: Optional[float] = -1.0,
        log_prob_low_threshold: Optional[float] = None,
        no_speech_threshold: Optional[float] = 0.6,
        initial_prompt: Optional[Union[str, Iterable[int]]] = None,
        prefix: Optional[str] = None,
        suppress_blank: bool = True,
        suppress_tokens: Optional[List[int]] = [-1],
        without_timestamps: bool = True,
        word_timestamps: bool = False,
        prepend_punctuations: str = "\"'“¿([{-",
        append_punctuations: str = "\"'.。,,!!??::”)]}、",
        vad_filter: bool = True,
        vad_parameters: Optional[Union[dict, VadOptions]] = None,
        max_new_tokens: Optional[int] = None,
        chunk_length: Optional[int] = None,
        clip_timestamps: Optional[List[dict]] = None,
        batch_size: int = 16,
        hotwords: Optional[str] = None,
    ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""transcribe audio in chunks in batched fashion and return with language info.
Arguments:
audio: Path to the input file (or a file-like object), or the audio waveform.
language: The language spoken in the audio. It should be a language code such
as "en" or "fr". If not set, the language will be detected in the first 30 seconds
of audio.
task: Task to execute (transcribe or translate).
log_progress: whether to show progress bar or not.
beam_size: Beam size to use for decoding.
best_of: Number of candidates when sampling with non-zero temperature.
patience: Beam search patience factor.
length_penalty: Exponential length penalty constant.
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
temperature: Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either
`compression_ratio_threshold` or `log_prob_threshold`.
compression_ratio_threshold: If the gzip compression ratio is above this value,
treat as failed.
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
This value should be less than log_prob_threshold.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.
initial_prompt: Optional text string or iterable of token ids to provide as a
prompt for the first window.
prefix: Optional text to provide as a prefix for the first window.
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in `tokenizer.non_speech_tokens()`.
without_timestamps: Only sample text tokens.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
and dynamic time warping, and include the timestamps for each word in each segment.
Set as False.
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
the maximum will be set by the default max_length.
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.
clip_timestamps: Optionally provide list of dictionaries each containing "start" and
"end" keys that specify the start and end of the voiced region within
`chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used.
batch_size: the maximum number of parallel requests to model for decoding.
hotwords:
Hotwords/hint phrases to the model. Has no effect if prefix is not None.
Static params: (Fixed for batched version)
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
multilingual: If True, perform transcription on multilingual videos. Set as False.
output_language: Valid only if multilingual is set to True.
Specifies the string representing the output language. One of
'en' (English) or 'hybrid' (code-switched transcription). set as None.
condition_on_previous_text: If True, the previous output of the model is provided
as a prompt for the next window; disabling may make the text inconsistent across
windows, but the model becomes less prone to getting stuck in a failure loop,
such as repetition looping or timestamps going out of sync. Set as False
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
Arg has effect only if condition_on_previous_text is True. Set at 0.5
#TODO: support "hallucination_silence_threshold" when "word_timestamps=True"
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected. set as None.
unused:
language_detection_threshold: If the maximum probability of the language tokens is
higher than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
Returns:
A tuple with:
- a generator over transcribed segments
- an instance of TranscriptionInfo
"""
        sampling_rate = self.model.feature_extractor.sampling_rate

        if not isinstance(audio, np.ndarray):
            audio = decode_audio(audio, sampling_rate=sampling_rate)
        duration = audio.shape[0] / sampling_rate

        chunk_length = chunk_length or self.model.feature_extractor.chunk_length
        # if no segment split is provided, use vad_model and generate segments
        if not clip_timestamps:
            if vad_filter:
                if vad_parameters is None:
                    vad_parameters = VadOptions(
                        max_speech_duration_s=chunk_length,
                        min_silence_duration_ms=160,
                    )
                elif isinstance(vad_parameters, dict):
                    if "max_speech_duration_s" in vad_parameters.keys():
                        vad_parameters.pop("max_speech_duration_s")

                    vad_parameters = VadOptions(
                        **vad_parameters, max_speech_duration_s=chunk_length
                    )
                active_segments = get_speech_timestamps(audio, vad_parameters)
                clip_timestamps = merge_segments(active_segments, vad_parameters)
            # run the audio if it is less than 30 sec even without clip_timestamps
            elif duration < chunk_length:
                clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
            else:
                raise RuntimeError(
                    "No clip timestamps found. "
                    "Set 'vad_filter' to True or provide 'clip_timestamps'."
                )

        if self.model.model.is_multilingual:
            language = language or self.preset_language
        elif language != "en":
            if language is not None:
                self.model.logger.warning(
                    f"English-only model is used, but {language} language is"
                    " chosen, setting language to 'en'."
                )
            language = "en"

        (
            language,
            language_probability,
            task,
            all_language_probs,
        ) = self.get_language_and_tokenizer(audio, task, language)

        duration_after_vad = (
            sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
            / sampling_rate
        )

        # batched options: see the difference with default options in WhisperModel
        batched_options = TranscriptionOptions(
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            log_prob_threshold=log_prob_threshold,
            log_prob_low_threshold=log_prob_low_threshold,
            no_speech_threshold=no_speech_threshold,
            compression_ratio_threshold=compression_ratio_threshold,
            temperatures=(
                temperature if isinstance(temperature, (list, tuple)) else [temperature]
            ),
            initial_prompt=initial_prompt,
            prefix=prefix,
            suppress_blank=suppress_blank,
            suppress_tokens=get_suppressed_tokens(self.tokenizer, suppress_tokens),
            prepend_punctuations=prepend_punctuations,
            append_punctuations=append_punctuations,
            max_new_tokens=max_new_tokens,
            hotwords=hotwords,
            word_timestamps=word_timestamps,
            hallucination_silence_threshold=None,
            condition_on_previous_text=False,
            clip_timestamps="0",
            prompt_reset_on_temperature=0.5,
            multilingual=False,
            output_language=None,
            without_timestamps=without_timestamps,
            max_initial_timestamp=0.0,
        )

        info = TranscriptionInfo(
            language=language,
            language_probability=language_probability,
            duration=duration,
            duration_after_vad=duration_after_vad,
            transcription_options=batched_options,
            vad_options=None,
            all_language_probs=all_language_probs,
        )

        audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
        features = (
            np.stack(
                [
                    pad_or_trim(
                        self.model.feature_extractor(chunk)[
                            ...,
                            : chunk.shape[0] // self.model.feature_extractor.hop_length,
                        ]
                    )
                    for chunk in audio_chunks
                ]
            )
            if duration_after_vad
            else []
        )

        segments = self._batched_segments_generator(
            features,
            chunks_metadata,
            batch_size,
            batched_options,
            log_progress,
        )

        return segments, info
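
    # Note: _batched_segments_generator below is a generator function, so
    # transcribe() returns immediately and decoding only happens while the
    # returned segments are iterated.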

    def _batched_segments_generator(
        self, features, chunks_metadata, batch_size, options, log_progress
    ):
        pbar = tqdm(total=len(features), disable=not log_progress, position=0)
        seg_idx = 0
        for i in range(0, len(features), batch_size):
            results = self.forward(
                features[i : i + batch_size],
                chunks_metadata[i : i + batch_size],
                **asdict(options),
            )

            for result in results:
                for segment in result:
                    seg_idx += 1
                    yield Segment(
                        seek=segment["seek"],
                        id=seg_idx,
                        text=segment["text"],
                        start=round(segment["start"], 3),
                        end=round(segment["end"], 3),
                        words=(
                            None
                            if not options.word_timestamps
                            else [Word(**word) for word in segment["words"]]
                        ),
                        tokens=segment["tokens"],
                        avg_logprob=segment["avg_logprob"],
                        no_speech_prob=segment["no_speech_prob"],
                        compression_ratio=segment["compression_ratio"],
                    )

                pbar.update(1)

        pbar.close()
        # revert the tokenizer if multilingual inference is enabled
        if self.preset_language is None:
            self.tokenizer = None
        self.last_speech_timestamp = 0.0
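

# A minimal usage sketch for the batched pipeline (illustrative only; the model
# name and audio path are placeholders):
#
#     model = WhisperModel("large-v3", device="cuda", compute_type="float16")
#     pipeline = BatchedInferencePipeline(model)
#     segments, info = pipeline.transcribe("audio.wav", batch_size=16)
#     for segment in segments:
#         print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))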


class WhisperModel:
    def __init__(
        self,
        model_size_or_path: str,
        device: str = "auto",
        device_index: Union[int, List[int]] = 0,
        compute_type: str = "default",
        cpu_threads: int = 0,
        num_workers: int = 1,
        download_root: Optional[str] = None,
        local_files_only: bool = False,
        files: dict = None,
        **model_kwargs,
    ):
        """Initializes the Whisper model.

        Args:
            model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
                small, small.en, distil-small.en, medium, medium.en, distil-medium.en,
                large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3,
                large-v3-turbo, or turbo), a path to a converted model directory, or a
                CTranslate2-converted Whisper model ID from the HF Hub. When a size or a
                model ID is configured, the converted model is downloaded from the
                Hugging Face Hub.
            device: Device to use for computation ("cpu", "cuda", "auto").
            device_index: Device ID to use.
                The model can also be loaded on multiple GPUs by passing a list of IDs
                (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in
                parallel when transcribe() is called from multiple Python threads
                (see also num_workers).
            compute_type: Type to use for computation.
                See https://opennmt.net/CTranslate2/quantization.html.
            cpu_threads: Number of threads to use when running on CPU (4 by default).
                A non-zero value overrides the OMP_NUM_THREADS environment variable.
            num_workers: When transcribe() is called from multiple Python threads,
                having multiple workers enables true parallelism when running the model
                (concurrent calls to self.model.generate() will run in parallel).
                This can improve the global throughput at the cost of increased memory usage.
            download_root: Directory where the models should be saved. If not set, the
                models are saved in the standard Hugging Face cache directory.
            local_files_only: If True, avoid downloading the file and return the path to
                the local cached file if it exists.
            files: Load model files from memory. This argument is a dictionary mapping
                file names to file contents as file-like or bytes objects. If this is set,
                model_size_or_path acts as an identifier for this model.
        """
        self.logger = get_logger()

        tokenizer_bytes, preprocessor_bytes = None, None
        if files:
            model_path = model_size_or_path
            tokenizer_bytes = files.pop("tokenizer.json", None)
            preprocessor_bytes = files.pop("preprocessor_config.json", None)
        elif os.path.isdir(model_size_or_path):
            model_path = model_size_or_path
        else:
            model_path = download_model(
                model_size_or_path,
                local_files_only=local_files_only,
                cache_dir=download_root,
            )
        self.device = device
        # set the random seed to ensure consistency across runs
        ctranslate2.set_random_seed(42)
        self.model = ctranslate2.models.Whisper(
            model_path,
            device=self.device,
            device_index=device_index,
            compute_type=compute_type,
            intra_threads=cpu_threads,
            inter_threads=num_workers,
            files=files,
            **model_kwargs,
        )

        tokenizer_file = os.path.join(model_path, "tokenizer.json")
        if tokenizer_bytes:
            self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
        elif os.path.isfile(tokenizer_file):
            self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
        else:
            self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
                "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
            )
        self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
        self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
        self.input_stride = 2
        self.num_samples_per_token = (
            self.feature_extractor.hop_length * self.input_stride
        )
        self.frames_per_second = (
            self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
        )
        self.tokens_per_second = (
            self.feature_extractor.sampling_rate // self.num_samples_per_token
        )
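        # With the default 16 kHz sampling rate and 160-sample hop length,
        # frames_per_second is 100 and tokens_per_second is 50, so one timestamp
        # token spans 1 / 50 = 0.02 s, matching the time precision set below.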
        self.time_precision = 0.02
        self.max_length = 448

    @property
    def supported_languages(self) -> List[str]:
        """The languages supported by the model."""
        return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]

    def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
        config = {}
        try:
            config_path = os.path.join(model_path, "preprocessor_config.json")
            if preprocessor_bytes:
                config = json.loads(preprocessor_bytes)
            elif os.path.isfile(config_path):
                with open(config_path, "r", encoding="utf-8") as file:
                    config = json.load(file)
            else:
                return config
            valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
            return {k: v for k, v in config.items() if k in valid_keys}
        except json.JSONDecodeError as e:
            self.logger.warning("Could not load preprocessor config: %s", e)

        return config
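
    # Note: only keys that match FeatureExtractor.__init__ parameters survive the
    # filtering above; any extra fields in preprocessor_config.json are silently
    # dropped.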

    def transcribe(
        self,
        audio: Union[str, BinaryIO, np.ndarray],
        language: Optional[str] = None,
        task: str = "transcribe",
        log_progress: bool = False,
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1,
        length_penalty: float = 1,
        repetition_penalty: float = 1,
        no_repeat_ngram_size: int = 0,
        temperature: Union[float, List[float], Tuple[float, ...]] = [
            0.0,
            0.2,
            0.4,
            0.6,
            0.8,
            1.0,
        ],
        compression_ratio_threshold: Optional[float] = 2.4,
        log_prob_threshold: Optional[float] = -1.0,
        log_prob_low_threshold: Optional[float] = None,
        no_speech_threshold: Optional[float] = 0.6,
        condition_on_previous_text: bool = True,
        prompt_reset_on_temperature: float = 0.5,
        initial_prompt: Optional[Union[str, Iterable[int]]] = None,
        prefix: Optional[str] = None,
        suppress_blank: bool = True,
        suppress_tokens: Optional[List[int]] = [-1],
        without_timestamps: bool = False,
        max_initial_timestamp: float = 1.0,
        word_timestamps: bool = False,
        prepend_punctuations: str = "\"'“¿([{-",
        append_punctuations: str = "\"'.。,,!!??::”)]}、",
        multilingual: bool = False,
        output_language: Optional[str] = None,
        vad_filter: bool = False,
        vad_parameters: Optional[Union[dict, VadOptions]] = None,
        max_new_tokens: Optional[int] = None,
        chunk_length: Optional[int] = None,
        clip_timestamps: Union[str, List[float]] = "0",
        hallucination_silence_threshold: Optional[float] = None,
        hotwords: Optional[str] = None,
        language_detection_threshold: Optional[float] = 0.5,
        language_detection_segments: int = 1,
    ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""Transcribes an input file.
Arguments:
audio: Path to the input file (or a file-like object), or the audio waveform.
language: The language spoken in the audio. It should be a language code such
as "en" or "fr". If not set, the language will be detected in the first 30 seconds
of audio.
task: Task to execute (transcribe or translate).
log_progress: whether to show progress bar or not.
beam_size: Beam size to use for decoding.
best_of: Number of candidates when sampling with non-zero temperature.
patience: Beam search patience factor.
length_penalty: Exponential length penalty constant.
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
temperature: Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either
`compression_ratio_threshold` or `log_prob_threshold`.
compression_ratio_threshold: If the gzip compression ratio is above this value,
treat as failed.
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
wheras log_prob_threshold also looks for appropriate no_speech_threshold value.
This value should be less than log_prob_threshold.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.
condition_on_previous_text: If True, the previous output of the model is provided
as a prompt for the next window; disabling may make the text inconsistent across
windows, but the model becomes less prone to getting stuck in a failure loop,
such as repetition looping or timestamps going out of sync.
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
Arg has effect only if condition_on_previous_text is True.
initial_prompt: Optional text string or iterable of token ids to provide as a
prompt for the first window.
prefix: Optional text to provide as a prefix for the first window.
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in `tokenizer.non_speech_tokens()`.
without_timestamps: Only sample text tokens.
max_initial_timestamp: The initial timestamp cannot be later than this.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
and dynamic time warping, and include the timestamps for each word in each segment.
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word
multilingual: If True, perform transcription on multilingual videos
and return the transcript based
on the 'output_language' flag.
output_language: Valid only if multilingual is set to True.
Specifies the string representing the output language. One of
'en' (English) or 'hybrid' (code-switched transcription).
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
the maximum will be set by the default max_length.
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.
clip_timestamps:
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
process. The last end timestamp defaults to the end of the file.
vad_filter will be ignored if clip_timestamps is used.
hallucination_silence_threshold:
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected
hotwords:
Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
language_detection_threshold: If the maximum probability of the language tokens is higher
than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
Returns:
A tuple with:
- a generator over transcribed segments
- an instance of TranscriptionInfo
"""
        sampling_rate = self.feature_extractor.sampling_rate

        if not isinstance(audio, np.ndarray):
            audio = decode_audio(audio, sampling_rate=sampling_rate)

        duration = audio.shape[0] / sampling_rate
        duration_after_vad = duration

        self.logger.info(
            "Processing audio with duration %s", format_timestamp(duration)
        )

        if vad_filter and clip_timestamps == "0":
            if vad_parameters is None:
                vad_parameters = VadOptions()
            elif isinstance(vad_parameters, dict):
                vad_parameters = VadOptions(**vad_parameters)
            speech_chunks = get_speech_timestamps(audio, vad_parameters)
            audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
            audio = np.concatenate(audio_chunks, axis=0)
            duration_after_vad = audio.shape[0] / sampling_rate

            self.logger.info(
                "VAD filter removed %s of audio",
                format_timestamp(duration - duration_after_vad),
            )

            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug(
                    "VAD filter kept the following audio segments: %s",
                    ", ".join(
                        "[%s -> %s]"
                        % (
                            format_timestamp(chunk["start"] / sampling_rate),
                            format_timestamp(chunk["end"] / sampling_rate),
                        )
                        for chunk in speech_chunks
                    ),
                )
        else:
            speech_chunks = None

        features = self.feature_extractor(audio, chunk_length=chunk_length)

        encoder_output = None
        all_language_probs = None

        # setting output_language for multilingual videos
        if multilingual:
            if output_language is None:
                output_language = "en"
            elif output_language not in ["en", "hybrid"]:
                raise ValueError("Output language needs to be one of 'en'/'hybrid'.")

        # detecting the language if not provided
        if language is None:
            if not self.model.is_multilingual:
                language = "en"
                language_probability = 1
            else:
                if (
                    language_detection_segments is None
                    or language_detection_segments < 1
                ):
                    language_detection_segments = 1
                start_timestamp = (
                    float(clip_timestamps.split(",")[0])
                    if isinstance(clip_timestamps, str)
                    else clip_timestamps[0]
                )
                content_frames = features.shape[-1] - 1
                seek = (
                    int(start_timestamp * self.frames_per_second)
                    if start_timestamp * self.frames_per_second < content_frames
                    else 0
                )
                end_frames = min(
                    seek
                    + self.feature_extractor.nb_max_frames
                    * language_detection_segments,
                    content_frames,
                )
                detected_language_info = {}
                while seek <= end_frames:
                    segment = features[
                        :, seek : seek + self.feature_extractor.nb_max_frames
                    ]
                    encoder_output = self.encode(pad_or_trim(segment))
                    # results is a list of tuple[str, float] with language names and
                    # probabilities.
                    results = self.model.detect_language(encoder_output)[0]
                    # Parse language names to strip out markers
                    all_language_probs = [
                        (token[2:-2], prob) for (token, prob) in results
                    ]
                    # Get top language token and probability
                    language, language_probability = all_language_probs[0]
                    if language_probability > language_detection_threshold:
                        break
                    detected_language_info.setdefault(language, []).append(
                        language_probability
                    )
                    seek += segment.shape[-1]
                else:
                    # If no segment clears the detection threshold, fall back to a
                    # majority vote over the top predicted language of each segment.
                    language = max(
                        detected_language_info,
                        key=lambda lang: len(detected_language_info[lang]),
                    )
                    language_probability = max(detected_language_info[language])

                self.logger.info(
                    "Detected language '%s' with probability %.2f",
                    language,
                    language_probability,
                )
        else:
            if not self.model.is_multilingual and language != "en":
                self.logger.warning(
                    "The current model is English-only but the language parameter is "
                    "set to '%s'; using 'en' instead." % language
                )
                language = "en"
            language_probability = 1

        tokenizer = Tokenizer(
            self.hf_tokenizer,
            self.model.is_multilingual,
            task=task,
            language=language,
        )

        options = TranscriptionOptions(
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            log_prob_threshold=log_prob_threshold,
            log_prob_low_threshold=log_prob_low_threshold,
            no_speech_threshold=no_speech_threshold,
            compression_ratio_threshold=compression_ratio_threshold,
            condition_on_previous_text=condition_on_previous_text,
            prompt_reset_on_temperature=prompt_reset_on_temperature,
            temperatures=(
                temperature if isinstance(temperature, (list, tuple)) else [temperature]
            ),
            initial_prompt=initial_prompt,
            prefix=prefix,
            suppress_blank=suppress_blank,
            suppress_tokens=(
                get_suppressed_tokens(tokenizer, suppress_tokens)
                if suppress_tokens
                else suppress_tokens
            ),
            without_timestamps=without_timestamps,
            max_initial_timestamp=max_initial_timestamp,
            word_timestamps=word_timestamps,
            prepend_punctuations=prepend_punctuations,
            append_punctuations=append_punctuations,
            multilingual=multilingual,
            output_language=output_language,
            max_new_tokens=max_new_tokens,
            clip_timestamps=clip_timestamps,
            hallucination_silence_threshold=hallucination_silence_threshold,
            hotwords=hotwords,
        )

        segments = self.generate_segments(
            features, tokenizer, options, log_progress, encoder_output
        )

        if speech_chunks:
            segments = restore_speech_timestamps(
                segments, speech_chunks, sampling_rate
            )

        info = TranscriptionInfo(
            language=language,
            language_probability=language_probability,
            duration=duration,
            duration_after_vad=duration_after_vad,
            transcription_options=options,
            vad_options=vad_parameters,
            all_language_probs=all_language_probs,
        )

        return segments, info
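
    # A minimal sequential-usage sketch (illustrative only; the audio path is a
    # placeholder):
    #
    #     model = WhisperModel("base.en", device="cpu", compute_type="int8")
    #     segments, info = model.transcribe("audio.wav", vad_filter=True)
    #     print(info.language, info.language_probability)
    #     for segment in segments:
    #         print(segment.start, segment.end, segment.text)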

    def _split_segments_by_timestamps(
        self,
        tokenizer: Tokenizer,
        tokens: List[int],
        time_offset: float,
        segment_size: int,
        segment_duration: float,
        seek: int,
    ) -> Tuple[List[dict], int, bool]:
        current_segments = []
        single_timestamp_ending = (
            len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
        )

        consecutive_timestamps = [
            i
            for i in range(len(tokens))
            if i > 0
            and tokens[i] >= tokenizer.timestamp_begin
            and tokens[i - 1] >= tokenizer.timestamp_begin
        ]

        if len(consecutive_timestamps) > 0:
            slices = list(consecutive_timestamps)
            if single_timestamp_ending:
                slices.append(len(tokens))

            last_slice = 0
            for current_slice in slices:
                sliced_tokens = tokens[last_slice:current_slice]
                start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin
                end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin
                start_time = (
                    time_offset + start_timestamp_position * self.time_precision