Remove the usage of transformers.pipeline from `BatchedInferencePipeline` and fix word timestamps for batched inference (#921)

* fix word timestamps for batched inference

* remove hf pipeline
MahmoudAshraf97 authored Jul 27, 2024
1 parent 83a368e commit d57c5b4
Showing 2 changed files with 72 additions and 172 deletions.
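For orientation, here is a minimal usage sketch of the pipeline after this change, assuming `BatchedInferencePipeline` is importable from `faster_whisper` and that the `transcribe` arguments below match the signature in the diff; names and defaults are illustrative, not authoritative:

    from faster_whisper import WhisperModel, BatchedInferencePipeline

    model = WhisperModel("large-v3", device="cuda", compute_type="float16")
    # The pipeline is now a plain Python class: no transformers.Pipeline
    # machinery (device handling, PipelineIterator, DataLoader) is involved.
    pipeline = BatchedInferencePipeline(model=model)

    # transcribe() runs VAD, batches the mel features, and lazily yields Segment objects.
    segments, info = pipeline.transcribe("audio.wav", batch_size=16, word_timestamps=True)
    for segment in segments:
        print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")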
240 changes: 70 additions & 170 deletions faster_whisper/transcribe.py
@@ -15,8 +15,7 @@
import torch

from pyannote.audio import Model
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
from tqdm import tqdm

from faster_whisper.audio import decode_audio, pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
@@ -105,7 +104,7 @@ class TranscriptionInfo(NamedTuple):
# (https://github.com/m-bain/whisperX) and adapted for faster_whisper


class BatchedInferencePipeline(Pipeline):
class BatchedInferencePipeline:
"""
Huggingface Pipeline wrapper for WhisperModel.
Copyright (c) 2022, Max Bain
@@ -119,55 +118,29 @@ def __init__(
use_vad_model: bool = True,
options: Optional[NamedTuple] = None,
tokenizer=None,
device: Union[int, str, "torch.device"] = -1,
chunk_length: int = 30,
vad_device: Union[int, str, "torch.device"] = "auto",
vad_onset: float = 0.500,
vad_offset: float = 0.363,
framework="pt",
language: Optional[str] = None,
**kwargs,
):
self.model: WhisperModel = model
self.tokenizer = tokenizer
self.options = options
self.preset_language = language
self._batch_size = kwargs.pop("batch_size", None)
self._num_workers = 0
self.use_vad_model = use_vad_model
self.vad_onset = vad_onset
self.vad_offset = vad_offset
self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
self.vad_model = None

(
self._preprocess_params,
self._forward_params,
self._postprocess_params,
) = self._sanitize_parameters(**kwargs)
self.call_count = 0
self.framework = framework
if self.framework == "pt":
self.device = self.get_device(device)
else:
self.device = device

if self.use_vad_model and self.vad_model is None:
if self.use_vad_model:
self.vad_device = self.get_device(vad_device)

# load vad model and perform VAD preprocessing if needed
self.vad_model = self.load_vad_model(
vad_onset=self.vad_onset, vad_offset=self.vad_offset
)
else:
self.vad_model = None
self.chunk_length = chunk_length # VAD merging size
self.last_speech_timestamp = 0.0
super(Pipeline, self).__init__()

def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
if "tokenizer" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
return preprocess_kwargs, {}, {}

def get_device(self, device: Union[int, str, "torch.device"]):
"""
@@ -193,27 +166,17 @@ def get_device(self, device: Union[int, str, "torch.device"]):
else:
return torch.device(f"cuda:{device}")

def preprocess(self, inputs):
audio = inputs["inputs"]
to_cpu = (
self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
)
features = self.model.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
:, : self.model.feature_extractor.nb_max_frames
]

inputs["features"] = features
del features
return inputs

def _forward(self, model_inputs, **forward_params):
def forward(self, features, segments_metadata, **forward_params):
encoder_output, outputs = self.model.generate_segment_batched(
model_inputs["features"], self.tokenizer, forward_params
features, self.tokenizer, forward_params
)

segment_size = encoder_output.shape[1] * 2
segmented_outputs = []
for segment_metadata, output in zip(model_inputs["seg_metadata"], outputs):
segment_sizes = []
for segment_metadata, output in zip(segments_metadata, outputs):
duration = segment_metadata["end_time"] - segment_metadata["start_time"]
segment_size = int(duration * self.model.frames_per_second)
segment_sizes.append(segment_size)
(
subsegments,
seek,
@@ -223,8 +186,7 @@ def _forward(self, model_inputs, **forward_params):
tokens=output["tokens"],
time_offset=segment_metadata["start_time"],
segment_size=segment_size,
segment_duration=segment_metadata["end_time"]
- segment_metadata["start_time"],
segment_duration=duration,
seek=0,
)
segmented_outputs.append(
@@ -248,89 +210,13 @@ def _forward(self, model_inputs, **forward_params):
segmented_outputs,
self.tokenizer,
encoder_output,
segment_size,
segment_sizes,
forward_params["prepend_punctuations"],
forward_params["append_punctuations"],
self.last_speech_timestamp,
)

return {"output": segmented_outputs}

def __call__(self, inputs, options, batch_size=None, **kwargs):
if batch_size is None:
if self._batch_size is None:
batch_size = 1
else:
batch_size = self._batch_size

(
preprocess_params,
forward_params,
postprocess_params,
) = self._sanitize_parameters(**kwargs)

# Fuse __init__ params and __call__ params without modifying the __init__ ones.
preprocess_params = {
**self._preprocess_params,
**preprocess_params,
}
options_dict = options._asdict()
forward_params = {**self._forward_params, **forward_params, **options_dict}
postprocess_params = {**self._postprocess_params, **postprocess_params}

self.call_count += 1
if (
self.call_count > 10
and self.framework == "pt"
and self.device.type == "cuda"
):
logging.warning(
"You seem to be using the pipelines sequentially on GPU. Please use a Dataset"
)

return self.get_iterator(
inputs,
batch_size,
preprocess_params,
forward_params,
postprocess_params,
)

def postprocess(self, model_outputs):
return model_outputs

def get_iterator(
self,
inputs,
batch_size: int,
preprocess_params=None,
forward_params=None,
postprocess_params=None,
):
def stack(items):
return {
"inputs": [x["inputs"] for x in items],
"seg_metadata": [x["seg_metadata"] for x in items],
"features": torch.stack([x["features"] for x in items]),
}

if "TOKENIZERS_PARALLELISM" not in os.environ:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=self._num_workers,
batch_size=batch_size,
collate_fn=stack,
)
model_iterator = PipelineIterator(
dataloader, self.forward, forward_params, loader_batch_size=batch_size
)
final_iterator = PipelineIterator(
model_iterator, self.postprocess, postprocess_params
)
return final_iterator
return segmented_outputs

def get_language_and_tokenizer(
self, audio, task: Optional[str] = None, language: Optional[str] = None
@@ -369,7 +255,8 @@ def get_language_and_tokenizer(
@staticmethod
def audio_split(audio, segments, sampling_rate):
"""Returns splitted audio chunks as iterator"""

audio_segments = []
segments_metadata = []
for seg in segments:
f1 = int(seg["start"] * sampling_rate)
f2 = int(seg["end"] * sampling_rate)
@@ -378,7 +265,9 @@ def audio_split(audio, segments, sampling_rate):
"end_time": seg["end"],
"stitched_seg": seg["segments"],
}
yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata}
audio_segments.append(audio[f1:f2])
segments_metadata.append(seg_metadata)
return audio_segments, segments_metadata

def load_vad_model(self, vad_onset=0.500, vad_offset=0.363):
vad_model = Model.from_pretrained(self.vad_model_path)
@@ -573,7 +462,6 @@ def transcribe(
task,
all_language_probs,
) = self.get_language_and_tokenizer(audio, task, language)
batch_size = batch_size or self._batch_size

duration_after_vad = sum(
segment["end"] - segment["start"] for segment in vad_segments
@@ -623,10 +511,27 @@
all_language_probs=all_language_probs,
)

audio_segments, segments_metadata = self.audio_split(
audio, vad_segments, sampling_rate
)
to_cpu = (
self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
)
audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor(
padding=0
)
features = torch.stack(
[
self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[
..., : self.model.feature_extractor.nb_max_frames
]
for audio_segment in audio_segments
]
)

segments = self._batched_segments_generator(
audio,
vad_segments,
sampling_rate,
features,
segments_metadata,
batch_size,
batched_options,
log_progress,
@@ -635,45 +540,40 @@
return segments, info

def _batched_segments_generator(
self, audio, vad_segments, sampling_rate, batch_size, options, log_progress
self, features, segments_metadata, batch_size, options, log_progress
):
pbar = tqdm(total=len(features), disable=not log_progress, position=0)
seg_idx = 0
total_segments = len(vad_segments)
for idx, out in enumerate(
self.__call__(
self.audio_split(audio, vad_segments, sampling_rate),
batch_size=batch_size,
options=options,
for i in range(0, len(features), batch_size):
results = self.forward(
features[i : i + batch_size],
segments_metadata[i : i + batch_size],
**options._asdict(),
)
):
if log_progress:
percent_complete = ((idx + 1) / total_segments) * 100
self.model.logger.info(f"Progress: {percent_complete:.2f}%...")

responses = out["output"]
if batch_size == 1:
responses = responses[0]

for response in responses:
seg_idx += 1
segments = Segment(
seek=int(responses[-1]["end"] * self.model.frames_per_second),
id=seg_idx,
text=response["text"],
start=round(response["start"], 3),
end=round(response["end"], 3),
words=(
None
if not options.word_timestamps
else [Word(**word) for word in response["words"]]
),
tokens=response["tokens"],
avg_logprob=response["avg_logprob"],
no_speech_prob=response["no_speech_prob"],
compression_ratio=response["compression_ratio"],
)
yield segments

for result in results:
for segment in result:
seg_idx += 1
yield Segment(
seek=int(result[-1]["end"] * self.model.frames_per_second),
id=seg_idx,
text=segment["text"],
start=round(segment["start"], 3),
end=round(segment["end"], 3),
words=(
None
if not options.word_timestamps
else [Word(**word) for word in segment["words"]]
),
tokens=segment["tokens"],
avg_logprob=segment["avg_logprob"],
no_speech_prob=segment["no_speech_prob"],
compression_ratio=segment["compression_ratio"],
)

pbar.update(1)

pbar.close()
# revert the tokenizer if multilingual inference is enabled
if self.preset_language is None:
self.tokenizer = None
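In summary, the data-loading path that previously went through transformers' PipelineIterator and a torch DataLoader is replaced by explicit padding and slicing inside transcribe(). A minimal sketch of the idea, assuming the VAD chunks are 1-D torch tensors and using hypothetical helper names (pad_and_stack, iter_batches) purely for illustration:

    import torch

    def pad_and_stack(audio_segments, feature_extractor, nb_max_frames, to_cpu=False):
        # Variable-length VAD chunks are zero-padded to a common length via a
        # nested tensor, then converted to log-Mel features one chunk at a time.
        padded = torch.nested.nested_tensor(audio_segments).to_padded_tensor(padding=0)
        return torch.stack(
            [feature_extractor(chunk, to_cpu=to_cpu)[..., :nb_max_frames] for chunk in padded]
        )

    def iter_batches(features, metadata, batch_size):
        # Plain slicing replaces the PipelineIterator/DataLoader pair.
        for i in range(0, len(features), batch_size):
            yield features[i : i + batch_size], metadata[i : i + batch_size]

The per-chunk segment sizes computed in forward() (duration * frames_per_second) are then passed to add_word_timestamps, so words are aligned against each chunk's actual duration rather than a single fixed segment size, which appears to be the word-timestamp fix referenced in the commit title.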
4 changes: 2 additions & 2 deletions requirements.txt
@@ -2,7 +2,7 @@ ctranslate2>=4.0,<5
huggingface_hub>=0.13
tokenizers>=0.13,<1
onnxruntime>=1.14,<2
transformers
pyannote-audio>=3.1.1
torch>=2.1.1
torchaudio>=2.1.2
torchaudio>=2.1.2
tqdm
