Fix Canary chunked infer on short audios (NVIDIA#8382)
* add multitaskAED longform infer

Signed-off-by: stevehuang52 <[email protected]>

* refactor and changed default beam=1 and len_pen=0

Signed-off-by: stevehuang52 <[email protected]>

* revert default len_pen=1.0

Signed-off-by: stevehuang52 <[email protected]>

* refactor

Signed-off-by: stevehuang52 <[email protected]>

* refactor

Signed-off-by: stevehuang52 <[email protected]>

* update doc

Signed-off-by: stevehuang52 <[email protected]>

* fix autocast

Signed-off-by: stevehuang52 <[email protected]>

* fix typo in docstring

Signed-off-by: stevehuang52 <[email protected]>

* fix for short segment inference

Signed-off-by: stevehuang52 <[email protected]>

* add rnnt chunk infer utils

Signed-off-by: stevehuang52 <[email protected]>

* add ctc chunked infer

Signed-off-by: stevehuang52 <[email protected]>

---------

Signed-off-by: stevehuang52 <[email protected]>
Signed-off-by: Sasha Meister <[email protected]>
stevehuang52 authored and sashameister committed Feb 15, 2024
1 parent 3ac5f22 commit 6e0a8f6
Showing 1 changed file with 137 additions and 12 deletions.
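For context, a minimal sketch of the padding behavior this commit changes; it is not part of the commit, and the feature sizes are illustrative assumptions. With pad_to_frame_len=False, the final chunk of a short audio keeps its true feature length instead of being zero-padded out to a full frame, so the model no longer decodes a buffer that is mostly padding:

import numpy as np

# Toy setup: 80 mel features, 4 s chunks at a 10 ms window stride
# (400 feature frames per chunk), but only 37 frames of real audio.
features = np.random.rand(80, 37).astype('float32')
feature_frame_len = 400
start = 0
samp_len = features.shape[1] - start  # 37 valid frames remain

for pad_to_frame_len in (True, False):
    if not pad_to_frame_len:
        # New behavior: emit only the valid frames.
        frame = np.zeros([features.shape[0], samp_len], dtype='float32')
    else:
        # Old behavior: zero-pad the final chunk to the full frame length.
        frame = np.zeros([features.shape[0], int(feature_frame_len)], dtype='float32')
    frame[:, 0:samp_len] = features[:, start : start + samp_len]
    print(pad_to_frame_len, frame.shape)  # True -> (80, 400); False -> (80, 37)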
149 changes: 137 additions & 12 deletions nemo/collections/asr/parts/utils/streaming_utils.py
@@ -468,12 +468,13 @@ def update_feature_buffer(self, chunk):
 
 
 class AudioFeatureIterator(IterableDataset):
-    def __init__(self, samples, frame_len, preprocessor, device):
+    def __init__(self, samples, frame_len, preprocessor, device, pad_to_frame_len=True):
         self._samples = samples
         self._frame_len = frame_len
         self._start = 0
         self.output = True
         self.count = 0
+        self.pad_to_frame_len = pad_to_frame_len
         timestep_duration = preprocessor._cfg['window_stride']
         self._feature_frame_len = frame_len / timestep_duration
         audio_signal = torch.from_numpy(self._samples).unsqueeze_(0).to(device)
@@ -492,8 +493,11 @@ def __next__(self):
             frame = self._features[:, self._start : last].cpu()
             self._start = last
         else:
-            frame = np.zeros([self._features.shape[0], int(self._feature_frame_len)], dtype='float32')
             samp_len = self._features_len[0] - self._start
+            if not self.pad_to_frame_len:
+                frame = np.zeros([self._features.shape[0], samp_len], dtype='float32')
+            else:
+                frame = np.zeros([self._features.shape[0], int(self._feature_frame_len)], dtype='float32')
             frame[:, 0:samp_len] = self._features[:, self._start : self._features_len[0]].cpu()
             self.output = False
         self.count += 1
@@ -553,7 +557,7 @@ def __next__(self):
         self._buf_count += 1
         return (
             torch.as_tensor(self.signal[self._buf_count - 1], dtype=torch.float32),
-            torch.as_tensor(self.signal_shape[1], dtype=torch.int64),
+            torch.as_tensor(self.signal[self._buf_count - 1].shape[1], dtype=torch.int64),
         )
 
     def set_signal(self, signals):
@@ -571,7 +575,7 @@ class FeatureFrameBufferer:
     an array of buffers.
     """
 
-    def __init__(self, asr_model, frame_len=1.6, batch_size=4, total_buffer=4.0):
+    def __init__(self, asr_model, frame_len=1.6, batch_size=4, total_buffer=4.0, pad_to_buffer_len=True):
         '''
         Args:
             frame_len: frame's duration, seconds
@@ -591,7 +595,7 @@ def __init__(self, asr_model, frame_len=1.6, batch_size=4, total_buffer=4.0):
         total_buffer_len = int(total_buffer / timestep_duration)
         self.n_feat = asr_model._cfg.preprocessor.features
         self.buffer = np.ones([self.n_feat, total_buffer_len], dtype=np.float32) * self.ZERO_LEVEL_SPEC_DB_VAL
-
+        self.pad_to_buffer_len = pad_to_buffer_len
         self.batch_size = batch_size
 
         self.signal_end = False
@@ -635,9 +639,13 @@ def get_frame_buffers(self, frames):
         # Build buffers for each frame
         self.frame_buffers = []
         for frame in frames:
-            self.buffer[:, : -self.n_frame_len] = self.buffer[:, self.n_frame_len :]
+            curr_frame_len = frame.shape[1]
+            self.buffered_len += curr_frame_len
+            if curr_frame_len < self.feature_buffer_len and not self.pad_to_buffer_len:
+                self.frame_buffers.append(np.copy(frame))
+                continue
+            self.buffer[:, :-curr_frame_len] = self.buffer[:, curr_frame_len:]
             self.buffer[:, -self.n_frame_len :] = frame
-            self.buffered_len += frame.shape[1]
             self.frame_buffers.append(np.copy(self.buffer))
         return self.frame_buffers
 
@@ -646,8 +654,12 @@ def set_frame_reader(self, frame_reader):
         self.signal_end = False
 
     def _update_feature_buffer(self, feat_frame):
-        self.feature_buffer[:, : -feat_frame.shape[1]] = self.feature_buffer[:, feat_frame.shape[1] :]
-        self.feature_buffer[:, -feat_frame.shape[1] :] = feat_frame
+        curr_frame_len = feat_frame.shape[1]
+        if curr_frame_len < self.feature_buffer_len and not self.pad_to_buffer_len:
+            self.feature_buffer = np.copy(feat_frame)  # assume that only the last frame is less than the buffer length
+        else:
+            self.feature_buffer[:, : -feat_frame.shape[1]] = self.feature_buffer[:, feat_frame.shape[1] :]
+            self.feature_buffer[:, -feat_frame.shape[1] :] = feat_frame
         self.buffered_features_size += feat_frame.shape[1]
 
     def get_norm_consts_per_frame(self, batch_frames):
Expand Down Expand Up @@ -689,7 +701,7 @@ class for streaming frame-based ASR use reset() method to reset FrameASR's
"""

def __init__(
self, asr_model, frame_len=1.6, total_buffer=4.0, batch_size=4,
self, asr_model, frame_len=1.6, total_buffer=4.0, batch_size=4, pad_to_buffer_len=True,
):
'''
Args:
@@ -698,7 +710,11 @@ def __init__(
             offset: number of symbols to drop for smooth streaming
         '''
         self.frame_bufferer = FeatureFrameBufferer(
-            asr_model=asr_model, frame_len=frame_len, batch_size=batch_size, total_buffer=total_buffer
+            asr_model=asr_model,
+            frame_len=frame_len,
+            batch_size=batch_size,
+            total_buffer=total_buffer,
+            pad_to_buffer_len=pad_to_buffer_len,
         )
 
         self.asr_model = asr_model
@@ -1561,6 +1577,9 @@ def get_all_audios(self):
 
 
 class FrameBatchMultiTaskAED(FrameBatchASR):
+    def __init__(self, asr_model, frame_len=4, total_buffer=4, batch_size=4):
+        super().__init__(asr_model, frame_len, total_buffer, batch_size, pad_to_buffer_len=False)
+
     def get_input_tokens(self, sample: dict):
         if self.asr_model.prompt_format == "canary":
             missing_keys = [k for k in ("source_lang", "target_lang", "taskname", "pnc") if k not in sample]
@@ -1584,7 +1603,12 @@ def get_input_tokens(self, sample: dict):
 
     def read_audio_file(self, audio_filepath: str, delay, model_stride_in_secs, meta_data):
         self.input_tokens = self.get_input_tokens(meta_data)
-        super().read_audio_file(audio_filepath, delay, model_stride_in_secs)
+        samples = get_samples(audio_filepath)
+        samples = np.pad(samples, (0, int(delay * model_stride_in_secs * self.asr_model._cfg.sample_rate)))
+        frame_reader = AudioFeatureIterator(
+            samples, self.frame_len, self.raw_preprocessor, self.asr_model.device, pad_to_frame_len=False
+        )
+        self.set_frame_reader(frame_reader)
 
     @torch.no_grad()
     def _get_batch_preds(self, keep_logits=False):
@@ -1614,3 +1638,104 @@ def transcribe(
 
         print("keep_logits=True is not supported for MultiTaskAEDFrameBatchInfer. Returning empty logits.")
         return hypothesis, []
+
+
+class FrameBatchChunkedRNNT(FrameBatchASR):
+    def __init__(self, asr_model, frame_len=4, total_buffer=4, batch_size=4):
+        super().__init__(asr_model, frame_len, total_buffer, batch_size, pad_to_buffer_len=False)
+
+    def read_audio_file(self, audio_filepath: str, delay, model_stride_in_secs):
+        samples = get_samples(audio_filepath)
+        samples = np.pad(samples, (0, int(delay * model_stride_in_secs * self.asr_model._cfg.sample_rate)))
+        frame_reader = AudioFeatureIterator(
+            samples, self.frame_len, self.raw_preprocessor, self.asr_model.device, pad_to_frame_len=False
+        )
+        self.set_frame_reader(frame_reader)
+
+    @torch.no_grad()
+    def _get_batch_preds(self, keep_logits=False):
+        device = self.asr_model.device
+        for batch in iter(self.data_loader):
+            feat_signal, feat_signal_len = batch
+            feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device)
+
+            encoded, encoded_len = self.asr_model(
+                processed_signal=feat_signal, processed_signal_length=feat_signal_len
+            )
+
+            best_hyp_text, all_hyp_text = self.asr_model.decoding.rnnt_decoder_predictions_tensor(
+                encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
+            )
+            self.all_preds.extend(best_hyp_text)
+            del best_hyp_text
+            del all_hyp_text
+            del encoded
+            del encoded_len
+
+    def transcribe(
+        self, tokens_per_chunk: Optional[int] = None, delay: Optional[int] = None, keep_logits: bool = False
+    ):
+        """
+        Unused params are kept for the same signature as the parent class.
+        """
+        self.infer_logits(keep_logits)
+
+        hypothesis = " ".join(self.all_preds)
+        if not keep_logits:
+            return hypothesis
+
+        print("keep_logits=True is not supported for FrameBatchChunkedRNNT. Returning empty logits.")
+        return hypothesis, []
+
+
+class FrameBatchChunkedCTC(FrameBatchASR):
+    def __init__(self, asr_model, frame_len=4, total_buffer=4, batch_size=4):
+        super().__init__(asr_model, frame_len, total_buffer, batch_size, pad_to_buffer_len=False)
+
+    def read_audio_file(self, audio_filepath: str, delay, model_stride_in_secs):
+        samples = get_samples(audio_filepath)
+        samples = np.pad(samples, (0, int(delay * model_stride_in_secs * self.asr_model._cfg.sample_rate)))
+        frame_reader = AudioFeatureIterator(
+            samples, self.frame_len, self.raw_preprocessor, self.asr_model.device, pad_to_frame_len=False
+        )
+        self.set_frame_reader(frame_reader)
+
+    @torch.no_grad()
+    def _get_batch_preds(self, keep_logits=False):
+        device = self.asr_model.device
+        for batch in iter(self.data_loader):
+            feat_signal, feat_signal_len = batch
+            feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device)
+
+            results = self.asr_model(processed_signal=feat_signal, processed_signal_length=feat_signal_len)
+            if len(results) == 2:  # hybrid model
+                encoded, encoded_len = results
+                log_probs = self.asr_model.ctc_decoder(encoder_output=encoded)
+                transcribed_texts, _ = self.asr_model.ctc_decoding.ctc_decoder_predictions_tensor(
+                    decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False,
+                )
+            else:
+                log_probs, encoded_len, predictions = results
+                transcribed_texts, _ = self.asr_model.decoding.ctc_decoder_predictions_tensor(
+                    decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False,
+                )
+
+            self.all_preds.extend(transcribed_texts)
+            del log_probs
+            del encoded_len
+            del results  # 'predictions' exists only in the non-hybrid branch, so delete the tuple instead
+
+    def transcribe(
+        self, tokens_per_chunk: Optional[int] = None, delay: Optional[int] = None, keep_logits: bool = False
+    ):
+        """
+        Unused params are kept for the same signature as the parent class.
+        """
+        self.infer_logits(keep_logits)
+
+        hypothesis = " ".join(self.all_preds)
+        if not keep_logits:
+            return hypothesis
+
+        print("keep_logits=True is not supported for FrameBatchChunkedCTC. Returning empty logits.")
+        return hypothesis, []
