diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py
index 83e57ece59e3..acf437311d50 100644
--- a/nemo/collections/asr/models/label_models.py
+++ b/nemo/collections/asr/models/label_models.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 import copy
 import itertools
+from collections import Counter
 from math import ceil
 from typing import Dict, List, Optional, Union
 
 import librosa
 import numpy as np
+import soundfile as sf
 import torch
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf, open_dict
@@ -444,7 +446,7 @@ def infer_file(self, path2audio_file):
             emb: speaker embeddings (Audio representations)
             logits: logits corresponding of final layer
         """
-        audio, sr = librosa.load(path2audio_file, sr=None)
+        audio, sr = sf.read(path2audio_file)
         target_sr = self._cfg.train_ds.get('sample_rate', 16000)
         if sr != target_sr:
             audio = librosa.core.resample(audio, orig_sr=sr, target_sr=target_sr)
@@ -452,7 +454,7 @@ def infer_file(self, path2audio_file):
         device = self.device
         audio = np.array([audio])
         audio_signal, audio_signal_len = (
-            torch.tensor(audio, device=device),
+            torch.tensor(audio, device=device, dtype=torch.float32),
             torch.tensor([audio_length], device=device),
         )
         mode = self.training
@@ -466,25 +468,79 @@ def infer_file(self, path2audio_file):
             self.unfreeze()
         del audio_signal, audio_signal_len
         return emb, logits
 
-    def get_label(self, path2audio_file):
+    @torch.no_grad()
+    def infer_segment(self, segment):
+        """
+        Args:
+            segment: 1D numpy array containing a segment of an audio signal
+
+        Returns:
+            emb: speaker embeddings (Audio representations)
+            logits: logits corresponding to the final layer
+        """
+        segment_length = segment.shape[0]
+
+        device = self.device
+        audio = np.array([segment])
+        audio_signal, audio_signal_len = (
+            torch.tensor(audio, device=device, dtype=torch.float32),
+            torch.tensor([segment_length], device=device),
+        )
+        mode = self.training
+        self.freeze()
+
+        logits, emb = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
+
+        self.train(mode=mode)
+        if mode is True:
+            self.unfreeze()
+        del audio_signal, audio_signal_len
+        return emb, logits
+
+    def get_label(
+        self, path2audio_file: str, segment_duration: float = np.inf, num_segments: int = 1, random_seed: Optional[int] = None
+    ):
         """
         Returns label of path2audio_file from classes the model was trained on.
         Args:
-            path2audio_file: path to audio wav file
+            path2audio_file (str): Path to audio wav file.
+            segment_duration (float): Random sample duration in seconds.
+            num_segments (int): Number of segments of the file to use for the majority vote.
+            random_seed (int): Seed for generating the starting positions of the segments.
 
         Returns:
             label: label corresponding to the trained model
         """
-        _, logits = self.infer_file(path2audio_file=path2audio_file)
+        audio, sr = sf.read(path2audio_file)
+        target_sr = self._cfg.train_ds.get('sample_rate', 16000)
+        if sr != target_sr:
+            audio = librosa.core.resample(audio, orig_sr=sr, target_sr=target_sr)
+        audio_length = audio.shape[0]
+
+        duration = target_sr * segment_duration
+        if duration > audio_length:
+            duration = audio_length
+        duration = int(duration)  # ensure an integer sample count for slicing
+
+        label_id_list = []
+        np.random.seed(random_seed)
+        starts = np.random.randint(0, audio_length - duration + 1, size=num_segments)
+        for start in starts:
+            segment = audio[start : start + duration]
+
+            _, logits = self.infer_segment(segment)
+            label_id = logits.argmax(axis=1)
+            label_id_list.append(int(label_id[0]))
+
+        m_label_id = Counter(label_id_list).most_common(1)[0][0]  # majority vote
 
         trained_labels = self._cfg['train_ds'].get('labels', None)
         if trained_labels is not None:
             trained_labels = list(trained_labels)
-            label_id = logits.argmax(axis=1)
-            label = trained_labels[int(label_id[0])]
+            label = trained_labels[m_label_id]
         else:
             logging.info("labels are not saved to model, hence only outputting the label id index")
-            label = m_label_id
+            label = m_label_id
 
         return label
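
Usage sketch for the new segment-based majority vote (illustrative only: the pretrained model name "titanet_large" and the wav path below are placeholders, not part of this diff):

    import nemo.collections.asr as nemo_asr

    speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("titanet_large")

    # Old behavior (still the default): a single forward pass over the whole file.
    label = speaker_model.get_label("speaker.wav")

    # New behavior: majority vote over five random 3-second segments;
    # random_seed makes the sampled segment starts reproducible.
    label = speaker_model.get_label("speaker.wav", segment_duration=3.0, num_segments=5, random_seed=42)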