From d57c5b40b06e59ec44240d93485a95799548af50 Mon Sep 17 00:00:00 2001
From: Mahmoud Ashraf
Date: Sat, 27 Jul 2024 05:02:58 +0300
Subject: [PATCH] Remove the usage of `transformers.pipeline` from
 `BatchedInferencePipeline` and fix word timestamps for batched inference
 (#921)

* fix word timestamps for batched inference

* remove hf pipeline
---
 faster_whisper/transcribe.py | 240 ++++++++++-------------------------
 requirements.txt             |   4 +-
 2 files changed, 72 insertions(+), 172 deletions(-)

diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 2934652f..8652e82b 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -15,8 +15,7 @@
 import torch
 from pyannote.audio import Model
-from transformers import Pipeline
-from transformers.pipelines.pt_utils import PipelineIterator
+from tqdm import tqdm
 
 from faster_whisper.audio import decode_audio, pad_or_trim
 from faster_whisper.feature_extractor import FeatureExtractor
@@ -105,7 +104,7 @@ class TranscriptionInfo(NamedTuple):
 # (https://github.com/m-bain/whisperX) and adapted for faster_whisper
-class BatchedInferencePipeline(Pipeline):
+class BatchedInferencePipeline:
     """
-    Huggingface Pipeline wrapper for WhisperModel.
+    Batched inference pipeline for WhisperModel.
     Copyright (c) 2022, Max Bain
@@ -119,55 +118,29 @@ def __init__(
         use_vad_model: bool = True,
         options: Optional[NamedTuple] = None,
         tokenizer=None,
-        device: Union[int, str, "torch.device"] = -1,
         chunk_length: int = 30,
         vad_device: Union[int, str, "torch.device"] = "auto",
         vad_onset: float = 0.500,
         vad_offset: float = 0.363,
-        framework="pt",
         language: Optional[str] = None,
-        **kwargs,
     ):
         self.model: WhisperModel = model
         self.tokenizer = tokenizer
         self.options = options
         self.preset_language = language
-        self._batch_size = kwargs.pop("batch_size", None)
-        self._num_workers = 0
         self.use_vad_model = use_vad_model
         self.vad_onset = vad_onset
        self.vad_offset = vad_offset
         self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
-        self.vad_model = None
-
-        (
-            self._preprocess_params,
-            self._forward_params,
-            self._postprocess_params,
-        ) = self._sanitize_parameters(**kwargs)
-        self.call_count = 0
-        self.framework = framework
-        if self.framework == "pt":
-            self.device = self.get_device(device)
-        else:
-            self.device = device
-
-        if self.use_vad_model and self.vad_model is None:
+        if self.use_vad_model:
             self.vad_device = self.get_device(vad_device)
-
-            # load vad model and perform VAD preprocessing if needed
             self.vad_model = self.load_vad_model(
                 vad_onset=self.vad_onset, vad_offset=self.vad_offset
             )
+        else:
+            self.vad_model = None
 
         self.chunk_length = chunk_length  # VAD merging size
         self.last_speech_timestamp = 0.0
-        super(Pipeline, self).__init__()
-
-    def _sanitize_parameters(self, **kwargs):
-        preprocess_kwargs = {}
-        if "tokenizer" in kwargs:
-            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
-        return preprocess_kwargs, {}, {}
 
     def get_device(self, device: Union[int, str, "torch.device"]):
         """
@@ -193,27 +166,17 @@ def get_device(self, device: Union[int, str, "torch.device"]):
         else:
             return torch.device(f"cuda:{device}")
 
-    def preprocess(self, inputs):
-        audio = inputs["inputs"]
-        to_cpu = (
-            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
-        )
-        features = self.model.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
-            :, : self.model.feature_extractor.nb_max_frames
-        ]
-
-        inputs["features"] = features
-        del features
-        return inputs
-
-    def _forward(self, model_inputs, **forward_params):
+    def forward(self, features, segments_metadata, 
**forward_params): encoder_output, outputs = self.model.generate_segment_batched( - model_inputs["features"], self.tokenizer, forward_params + features, self.tokenizer, forward_params ) - segment_size = encoder_output.shape[1] * 2 segmented_outputs = [] - for segment_metadata, output in zip(model_inputs["seg_metadata"], outputs): + segment_sizes = [] + for segment_metadata, output in zip(segments_metadata, outputs): + duration = segment_metadata["end_time"] - segment_metadata["start_time"] + segment_size = int(duration * self.model.frames_per_second) + segment_sizes.append(segment_size) ( subsegments, seek, @@ -223,8 +186,7 @@ def _forward(self, model_inputs, **forward_params): tokens=output["tokens"], time_offset=segment_metadata["start_time"], segment_size=segment_size, - segment_duration=segment_metadata["end_time"] - - segment_metadata["start_time"], + segment_duration=duration, seek=0, ) segmented_outputs.append( @@ -248,89 +210,13 @@ def _forward(self, model_inputs, **forward_params): segmented_outputs, self.tokenizer, encoder_output, - segment_size, + segment_sizes, forward_params["prepend_punctuations"], forward_params["append_punctuations"], self.last_speech_timestamp, ) - return {"output": segmented_outputs} - - def __call__(self, inputs, options, batch_size=None, **kwargs): - if batch_size is None: - if self._batch_size is None: - batch_size = 1 - else: - batch_size = self._batch_size - - ( - preprocess_params, - forward_params, - postprocess_params, - ) = self._sanitize_parameters(**kwargs) - - # Fuse __init__ params and __call__ params without modifying the __init__ ones. - preprocess_params = { - **self._preprocess_params, - **preprocess_params, - } - options_dict = options._asdict() - forward_params = {**self._forward_params, **forward_params, **options_dict} - postprocess_params = {**self._postprocess_params, **postprocess_params} - - self.call_count += 1 - if ( - self.call_count > 10 - and self.framework == "pt" - and self.device.type == "cuda" - ): - logging.warning( - "You seem to be using the pipelines sequentially on GPU. 
Please use a Dataset"
-            )
-
-        return self.get_iterator(
-            inputs,
-            batch_size,
-            preprocess_params,
-            forward_params,
-            postprocess_params,
-        )
-
-    def postprocess(self, model_outputs):
-        return model_outputs
-
-    def get_iterator(
-        self,
-        inputs,
-        batch_size: int,
-        preprocess_params=None,
-        forward_params=None,
-        postprocess_params=None,
-    ):
-        def stack(items):
-            return {
-                "inputs": [x["inputs"] for x in items],
-                "seg_metadata": [x["seg_metadata"] for x in items],
-                "features": torch.stack([x["features"] for x in items]),
-            }
-
-        if "TOKENIZERS_PARALLELISM" not in os.environ:
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
-        dataloader = torch.utils.data.DataLoader(
-            dataset,
-            num_workers=self._num_workers,
-            batch_size=batch_size,
-            collate_fn=stack,
-        )
-        model_iterator = PipelineIterator(
-            dataloader, self.forward, forward_params, loader_batch_size=batch_size
-        )
-        final_iterator = PipelineIterator(
-            model_iterator, self.postprocess, postprocess_params
-        )
-        return final_iterator
+        return segmented_outputs
 
     def get_language_and_tokenizer(
         self, audio, task: Optional[str] = None, language: Optional[str] = None
@@ -369,7 +255,8 @@ def transcribe(
     @staticmethod
     def audio_split(audio, segments, sampling_rate):
-        """Returns splitted audio chunks as iterator"""
-
+        """Returns split audio chunks and their metadata"""
+        audio_segments = []
+        segments_metadata = []
         for seg in segments:
             f1 = int(seg["start"] * sampling_rate)
             f2 = int(seg["end"] * sampling_rate)
@@ -378,7 +265,9 @@ def audio_split(audio, segments, sampling_rate):
                 "end_time": seg["end"],
                 "stitched_seg": seg["segments"],
             }
-            yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata}
+            audio_segments.append(audio[f1:f2])
+            segments_metadata.append(seg_metadata)
+        return audio_segments, segments_metadata
 
     def load_vad_model(self, vad_onset=0.500, vad_offset=0.363):
         vad_model = Model.from_pretrained(self.vad_model_path)
@@ -573,7 +462,6 @@ def transcribe(
             task,
             all_language_probs,
         ) = self.get_language_and_tokenizer(audio, task, language)
-        batch_size = batch_size or self._batch_size
 
         duration_after_vad = sum(
             segment["end"] - segment["start"] for segment in vad_segments
@@ -623,10 +511,27 @@ def transcribe(
             all_language_probs=all_language_probs,
         )
 
+        audio_segments, segments_metadata = self.audio_split(
+            audio, vad_segments, sampling_rate
+        )
+        to_cpu = (
+            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
+        )
+        audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor(
+            padding=0
+        )
+        features = torch.stack(
+            [
+                self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[
+                    ..., : self.model.feature_extractor.nb_max_frames
+                ]
+                for audio_segment in audio_segments
+            ]
+        )
+
         segments = self._batched_segments_generator(
-            audio,
-            vad_segments,
-            sampling_rate,
+            features,
+            segments_metadata,
             batch_size,
             batched_options,
             log_progress,
@@ -635,45 +540,40 @@ def transcribe(
         return segments, info
 
     def _batched_segments_generator(
-        self, audio, vad_segments, sampling_rate, batch_size, options, log_progress
+        self, features, segments_metadata, batch_size, options, log_progress
     ):
+        pbar = tqdm(total=len(features), disable=not log_progress, position=0)
         seg_idx = 0
-        total_segments = len(vad_segments)
-        for idx, out in enumerate(
-            self.__call__(
-                self.audio_split(audio, vad_segments, sampling_rate),
-                batch_size=batch_size,
-                options=options,
+        for i in range(0, len(features), batch_size):
+            results = self.forward(
+                features[i : i + 
batch_size],
+                segments_metadata[i : i + batch_size],
+                **options._asdict(),
             )
-        ):
-            if log_progress:
-                percent_complete = ((idx + 1) / total_segments) * 100
-                self.model.logger.info(f"Progress: {percent_complete:.2f}%...")
-
-            responses = out["output"]
-            if batch_size == 1:
-                responses = responses[0]
-
-            for response in responses:
-                seg_idx += 1
-                segments = Segment(
-                    seek=int(responses[-1]["end"] * self.model.frames_per_second),
-                    id=seg_idx,
-                    text=response["text"],
-                    start=round(response["start"], 3),
-                    end=round(response["end"], 3),
-                    words=(
-                        None
-                        if not options.word_timestamps
-                        else [Word(**word) for word in response["words"]]
-                    ),
-                    tokens=response["tokens"],
-                    avg_logprob=response["avg_logprob"],
-                    no_speech_prob=response["no_speech_prob"],
-                    compression_ratio=response["compression_ratio"],
-                )
-                yield segments
+            for result in results:
+                for segment in result:
+                    seg_idx += 1
+                    yield Segment(
+                        seek=int(result[-1]["end"] * self.model.frames_per_second),
+                        id=seg_idx,
+                        text=segment["text"],
+                        start=round(segment["start"], 3),
+                        end=round(segment["end"], 3),
+                        words=(
+                            None
+                            if not options.word_timestamps
+                            else [Word(**word) for word in segment["words"]]
+                        ),
+                        tokens=segment["tokens"],
+                        avg_logprob=segment["avg_logprob"],
+                        no_speech_prob=segment["no_speech_prob"],
+                        compression_ratio=segment["compression_ratio"],
+                    )
+
+            pbar.update(1)
+
+        pbar.close()
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
diff --git a/requirements.txt b/requirements.txt
index 6516f96c..e0a3afba 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ ctranslate2>=4.0,<5
 huggingface_hub>=0.13
 tokenizers>=0.13,<1
 onnxruntime>=1.14,<2
-transformers
 pyannote-audio>=3.1.1
 torch>=2.1.1
-torchaudio>=2.1.2
\ No newline at end of file
+torchaudio>=2.1.2
+tqdm
\ No newline at end of file
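
Usage sketch (illustrative, not part of the patch): after this change,
`BatchedInferencePipeline` is a plain Python class driven directly, with no
Hugging Face `Pipeline` machinery involved. A minimal example, assuming the
constructor and `transcribe()` signature shown in the diff above; the model
size, device, compute type, audio path, and batch size below are illustrative
placeholders rather than values taken from this patch:

    from faster_whisper import WhisperModel
    from faster_whisper.transcribe import BatchedInferencePipeline

    # Load a CTranslate2 Whisper model (model size and device are assumptions).
    model = WhisperModel("large-v2", device="cuda", compute_type="float16")

    # Wrap the model in the refactored batched pipeline; VAD is enabled by
    # default (use_vad_model=True) and splits the audio before batching.
    pipeline = BatchedInferencePipeline(model=model)

    # transcribe() returns a lazy generator of Segment namedtuples plus info.
    # batch_size controls how many VAD-derived chunks go through forward() per
    # step, and word_timestamps=True exercises the word-alignment path this
    # patch fixes for batched inference.
    segments, info = pipeline.transcribe(
        "audio.wav", batch_size=16, word_timestamps=True, log_progress=True
    )

    for segment in segments:
        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")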