Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LID: several random samples for long file #6853

Merged
merged 32 commits into from
Nov 6, 2023
Merged
Changes from 24 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
44fbcd8
add duration_limit
karpnv Jun 7, 2023
3a2303a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into karpnv/dur…
karpnv Jun 12, 2023
034bb3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2023
ebd9f1b
target_sr
karpnv Jun 14, 2023
30cee72
Merge branch 'main' of https://github.com/NVIDIA/NeMo into karpnv/dur…
karpnv Jun 14, 2023
5c1034d
limit first
karpnv Jul 6, 2023
facb08d
Merge branch 'main' of https://github.com/NVIDIA/NeMo into karpnv/dur…
karpnv Jul 6, 2023
c6eae3e
soundfile
karpnv Jul 6, 2023
d936bc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 6, 2023
54da5d8
rm soudfile
karpnv Jul 10, 2023
06ff20e
Merge branch 'karpnv/duration_limit' of https://github.com/NVIDIA/NeM…
karpnv Jul 10, 2023
541ae3a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into karpnv/dur…
karpnv Jul 10, 2023
4125955
Merge branch 'main' of https://github.com/NVIDIA/NeMo into karpnv/dur…
karpnv Sep 29, 2023
20e67ab
Merge branch 'main' of https://github.com/NVIDIA/NeMo into karpnv/dur…
karpnv Oct 24, 2023
514f66d
infer_segment
karpnv Oct 24, 2023
270789f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2023
68bd103
soundfile
karpnv Oct 24, 2023
c41b09f
Merge branch 'karpnv/duration_limit' of https://github.com/NVIDIA/NeM…
karpnv Oct 24, 2023
4a0acc6
docstring
karpnv Oct 24, 2023
961893d
Merge branch 'main' of https://github.com/NVIDIA/NeMo into karpnv/dur…
karpnv Oct 25, 2023
360fcf4
soundfile
karpnv Oct 25, 2023
9b15855
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2023
26bcfa7
type float
karpnv Oct 25, 2023
93ecb72
Merge branch 'karpnv/duration_limit' of https://github.com/NVIDIA/NeM…
karpnv Oct 25, 2023
783121a
random_seed
karpnv Oct 26, 2023
5000160
Merge branch 'main' of https://github.com/NVIDIA/NeMo into karpnv/dur…
karpnv Oct 26, 2023
476dec3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2023
6824fce
Merge branch 'main' into karpnv/duration_limit
karpnv Oct 27, 2023
58ffd07
Merge branch 'main' into karpnv/duration_limit
karpnv Oct 31, 2023
7619746
Merge branch 'main' of https://github.com/NVIDIA/NeMo into karpnv/dur…
karpnv Nov 2, 2023
bae985e
Merge branch 'main' of https://github.com/NVIDIA/NeMo into karpnv/dur…
karpnv Nov 3, 2023
3df2df1
to float
karpnv Nov 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 59 additions & 7 deletions nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
# limitations under the License.
import copy
import itertools
import random
from collections import Counter
from math import ceil
from typing import Dict, List, Optional, Union

import librosa
import numpy as np
import soundfile as sf
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf, open_dict
Expand Down Expand Up @@ -444,7 +447,7 @@ def infer_file(self, path2audio_file):
emb: speaker embeddings (Audio representations)
logits: logits corresponding of final layer
"""
audio, sr = librosa.load(path2audio_file, sr=None)
audio, sr = sf.read(path2audio_file)
target_sr = self._cfg.train_ds.get('sample_rate', 16000)
if sr != target_sr:
audio = librosa.core.resample(audio, orig_sr=sr, target_sr=target_sr)
Expand All @@ -466,25 +469,74 @@ def infer_file(self, path2audio_file):
del audio_signal, audio_signal_len
return emb, logits

def get_label(self, path2audio_file):
@torch.no_grad()
def infer_segment(self, segment):
"""
Args:
segment: segment of audio file

Returns:
emb: speaker embeddings (Audio representations)
logits: logits corresponding of final layer
"""
segment_length = segment.shape[0]

device = self.device
audio = np.array([segment])
audio_signal, audio_signal_len = (
torch.tensor(audio, device=device, dtype=torch.float32),
torch.tensor([segment_length], device=device),
)
mode = self.training
self.freeze()

logits, emb = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)

self.train(mode=mode)
if mode is True:
self.unfreeze()
del audio_signal, audio_signal_len
return emb, logits

def get_label(self, path2audio_file: str, segment_duration: float = np.inf, num_segments: int = 1):
"""
Returns label of path2audio_file from classes the model was trained on.
Args:
path2audio_file: path to audio wav file
path2audio_file (str): path to audio wav file
segment_duration (float): random sample duration in seconds
num_segments (int): number of segments of file to use for majority vote
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of num_segments, just do non-overlap segments from start to end based on 5 sec audio samples? Have you done ablation study on what is best?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally didn't, but it was suggested by Fai. This is for very long audio (several hours). We take several segments and get result by majority vote


Returns:
label: label corresponding to the trained model
"""
_, logits = self.infer_file(path2audio_file=path2audio_file)
audio, sr = sf.read(path2audio_file)
target_sr = self._cfg.train_ds.get('sample_rate', 16000)
if sr != target_sr:
audio = librosa.core.resample(audio, orig_sr=sr, target_sr=target_sr)
audio_length = audio.shape[0]

duration = target_sr * segment_duration
if duration > audio_length:
duration = audio_length

label_id_list = []
for j in range(0, num_segments):
start = random.randint(0, audio_length - duration)
audio = audio[start : start + duration]

_, logits = self.infer_segment(audio)
label_id = logits.argmax(axis=1)
label_id_list.append(int(label_id[0]))

m_label_id = Counter(label_id_list).most_common(1)[0][0]

trained_labels = self._cfg['train_ds'].get('labels', None)
if trained_labels is not None:
trained_labels = list(trained_labels)
label_id = logits.argmax(axis=1)
label = trained_labels[int(label_id[0])]
label = trained_labels[m_label_id]
else:
logging.info("labels are not saved to model, hence only outputting the label id index")
label = logits.argmax(axis=1)
label = m_label_id

return label

Expand Down
Loading