Add data augmentation (#11)
* add data aug, inference broken

* add more aug

* bf16

* fix data aug - working
loubbrad committed Mar 1, 2024
1 parent 4d12799 commit f6f5fbb
Showing 9 changed files with 440 additions and 68 deletions.
195 changes: 194 additions & 1 deletion amt/audio.py
@@ -1,15 +1,19 @@
"""Contains code taken from https://github.com/openai/whisper"""

import os
import random
import torch
import torchaudio
import torch.nn.functional as F
import torchaudio.functional as AF
import numpy as np

from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union

from amt.config import load_config
from amt.tokenizer import AmtTokenizer

# hard-coded audio hyperparameters
config = load_config()["audio"]
@@ -176,3 +180,192 @@ def log_mel_spectrogram(
log_spec = (log_spec + 4.0) / 4.0

return log_spec


class AudioTransform(torch.nn.Module):
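    """Applies batched audio augmentation (reverb, additive noise, optional
    pitch shift, time/frequency masking) and computes log-mel features."""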
def __init__(
self,
reverb_factor: int = 1,
min_snr: int = 10,
max_snr: int = 40,
):
super().__init__()
self.tokenizer = AmtTokenizer()
self.reverb_factor = reverb_factor
self.min_snr = min_snr
self.max_snr = max_snr

self.config = load_config()["audio"]
self.sample_rate = self.config["sample_rate"]
self.chunk_len = self.config["chunk_len"]
self.num_samples = self.sample_rate * self.chunk_len

# Audio aug
impulse_paths = self._get_paths(
os.path.join(os.path.dirname(__file__), "assets", "impulse")
)
noise_paths = self._get_paths(
os.path.join(os.path.dirname(__file__), "assets", "noise")
)

# Register impulses and noises as buffers
self.num_impulse = 0
for i, impulse in enumerate(self._get_impulses(impulse_paths)):
self.register_buffer(f"impulse_{i}", impulse)
self.num_impulse += 1

self.num_noise = 0
for i, noise in enumerate(self._get_noise(noise_paths)):
self.register_buffer(f"noise_{i}", noise)
self.num_noise += 1

self.spec_transform = torchaudio.transforms.Spectrogram(
n_fft=self.config["n_fft"],
hop_length=self.config["hop_len"],
)
self.mel_transform = torchaudio.transforms.MelScale(
n_mels=self.config["n_mels"],
sample_rate=self.config["sample_rate"],
n_stft=self.config["n_fft"] // 2 + 1,
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.FrequencyMasking(
freq_mask_param=15, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=500, iid_masks=True
),
)

def _get_paths(self, dir_path):
return [
os.path.join(dir_path, f)
for f in os.listdir(dir_path)
if os.path.isfile(os.path.join(dir_path, f))
]

def _get_impulses(self, impulse_paths: list):
impulses = [torchaudio.load(path) for path in impulse_paths]
impulses = [
AF.resample(
waveform=wav, orig_freq=sr, new_freq=config["sample_rate"]
).mean(0, keepdim=True)[:, : 5 * self.sample_rate]
for wav, sr in impulses
]
return [
(wav) / (torch.linalg.vector_norm(wav, ord=2)) for wav in impulses
]

def _get_noise(self, noise_paths: list):
noises = [torchaudio.load(path) for path in noise_paths]
noises = [
AF.resample(
waveform=wav, orig_freq=sr, new_freq=config["sample_rate"]
).mean(0, keepdim=True)[:, : self.num_samples]
for wav, sr in noises
]

for wav in noises:
assert wav.shape[-1] == self.num_samples, "noise wav too short"

return noises

def apply_reverb(self, wav: torch.Tensor):
# wav: (bz, L)
batch_size, _ = wav.shape

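        # Sample an independent wet/dry mix in [0, 1] for each batch element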
reverb_strength = (
torch.Tensor([random.uniform(0, 1) for _ in range(batch_size)])
.unsqueeze(-1)
.to(wav.device)
)
reverb_type = random.randint(0, self.num_impulse - 1)
impulse = getattr(self, f"impulse_{reverb_type}")

reverb = AF.fftconvolve(wav, impulse, mode="full")[
:, : self.num_samples
]
if self.reverb_factor > 1:
for _ in range(self.reverb_factor - 1):
                reverb = AF.fftconvolve(reverb, impulse, mode="full")[
                    :, : self.num_samples
                ]

res = (reverb_strength * reverb) + ((1 - reverb_strength) * wav)

return res

    def apply_noise(self, wav: torch.Tensor):
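        # Mix one randomly chosen noise buffer into the batch at per-example SNRs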
batch_size, _ = wav.shape

snr_dbs = torch.tensor(
[
random.randint(self.min_snr, self.max_snr)
for _ in range(batch_size)
]
).to(wav.device)
noise_type = random.randint(0, self.num_noise - 1)
noise = getattr(self, f"noise_{noise_type}")

return AF.add_noise(waveform=wav, noise=noise, snr=snr_dbs)

def shift_spec(self, specs: torch.Tensor, shift: int):
if shift == 0:
return specs

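        # A pitch shift of n semitones scales every frequency by 2**(n/12);
        # emulate it by resampling the spectrogram's frequency axis.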
freq_mult = 2 ** (shift / 12.0)
_, num_bins, L = specs.shape
new_num_bins = int(num_bins * freq_mult)

# Interpolate expects extra channel dim
specs = specs.unsqueeze(1)
shifted_specs = torch.nn.functional.interpolate(
specs, size=(new_num_bins, L), mode="bilinear", align_corners=False
)
shifted_specs = shifted_specs.squeeze(1)

if shift > 0:
shifted_specs = shifted_specs[:, :num_bins, :]
else:
padding = num_bins - shifted_specs.size(1)
shifted_specs = torch.nn.functional.pad(
shifted_specs, (0, 0, 0, padding), "constant", 0
)

return shifted_specs

def aug_wav(self, wav: torch.Tensor):
return self.apply_reverb(self.apply_noise(wav))

def norm_mel(self, mel_spec: torch.Tensor):
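        # Whisper-style normalization: log10, clamp dynamic range to within
        # 8.0 of the per-example max, then rescale to roughly [-1, 1]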
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
max_over_mels = log_spec.max(dim=1, keepdim=True)[0]
max_log_spec = max_over_mels.max(dim=2, keepdim=True)[0]
log_spec = torch.maximum(log_spec, max_log_spec - 8.0)
log_spec = (log_spec + 4.0) / 4.0

return log_spec

def log_mel(self, wav: torch.Tensor, shift: int | None = None):
spec = self.spec_transform(wav)[..., :-1]
        if shift:
spec = self.shift_spec(spec, shift)
mel_spec = self.mel_transform(spec)

# Norm
log_spec = self.norm_mel(mel_spec)

return log_spec

def forward(self, wav: torch.Tensor, shift: int = 0):
# Reverb & noise
wav = self.aug_wav(wav)

# Spec & pitch shift
log_mel = self.log_mel(wav, shift)

        # Spec aug: frequency/time masking, applied with probability 0.8
if random.random() > 0.2:
log_mel = self.spec_aug(log_mel)

return log_mel
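
A minimal usage sketch of the new transform (not part of this commit): the batch size, shift range, and random-waveform stand-in are illustrative assumptions, and AudioTransform.__init__ expects the impulse/noise assets under amt/assets to exist.

import random
import torch
from amt.audio import AudioTransform

transform = AudioTransform()  # registered impulse/noise buffers move with .to(device)
wav = torch.randn(4, transform.num_samples)  # stand-in for a batch of audio chunks
log_mel = transform(wav, shift=random.choice([-1, 0, 1]))  # -> (4, n_mels, frames)
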
70 changes: 36 additions & 34 deletions amt/data.py
@@ -3,59 +3,61 @@
import shutil
import orjson
import torch
import torchaudio

from multiprocessing import Pool

from aria.data.midi import MidiDict
from amt.tokenizer import AmtTokenizer
from amt.config import load_config
from amt.audio import pad_or_trim


def get_wav_mid_segments(
audio_path: str, mid_path: str = "", return_json: bool = False
):
"""This function yields tuples of matched log mel spectrograms and
tokenized sequences (np.array, list). If it is given only an audio path
then it will return an empty list for the mid_feature
"""
tokenizer = AmtTokenizer()
    config = load_config()
stride_factor = config["data"]["stride_factor"]
sample_rate = config["audio"]["sample_rate"]
chunk_len = config["audio"]["chunk_len"]
num_samples = sample_rate * chunk_len
samples_per_ms = sample_rate // 1000

if not os.path.isfile(audio_path):
return None

# Load midi if required
if mid_path == "":
        midi_dict = None
elif not os.path.isfile(mid_path):
return None

    else:
        midi_dict = MidiDict.from_midi(mid_path)

# Load audio
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(0)  # collapse channels to mono before (optional) resampling
    if sr != sample_rate:
        wav = torchaudio.functional.resample(
            waveform=wav,
            orig_freq=sr,
            new_freq=sample_rate,
        )

# Create features
total_samples = wav.shape[-1]
res = []
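    # Slide a num_samples window over the wav with stride num_samples // stride_factor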
for idx in range(0, total_samples, num_samples // stride_factor):
audio_feature = pad_or_trim(wav[idx:], length=num_samples)
if midi_dict is not None:
mid_feature = tokenizer._tokenize_midi_dict(
midi_dict=midi_dict,
                start_ms=idx // samples_per_ms,
                end_ms=(idx + num_samples) // samples_per_ms,
)
else:
mid_feature = []
@@ -70,7 +72,7 @@ def get_features(

def write_features(args):
audio_path, mid_path, save_path = args
    features = get_wav_mid_segments(
audio_path=audio_path,
mid_path=mid_path,
return_json=False,
@@ -79,10 +81,10 @@ def write_features(args):
proc_save_path = os.path.join(dirname, str(os.getpid()) + basename)

with open(proc_save_path, mode="ab") as file:
        for wav, seq in features:
file.write(
orjson.dumps(
                    wav.numpy(),
option=orjson.OPT_SERIALIZE_NUMPY,
)
)
@@ -97,7 +99,7 @@ class AmtDataset(torch.utils.data.Dataset):
def __init__(self, load_path: str):
self.tokenizer = AmtTokenizer(return_tensors=True)
self.config = load_config()["data"]
        self.mixup_fn = self.tokenizer.export_msg_mixup()
self.file_buff = open(load_path, mode="r")
self.file_mmap = mmap.mmap(
self.file_buff.fileno(), 0, access=mmap.ACCESS_READ
@@ -126,11 +128,11 @@ def _format(tok):
self.file_mmap.seek(self.index[idx])

# Load data from line
        wav = torch.tensor(orjson.loads(self.file_mmap.readline()))
_seq = orjson.loads(self.file_mmap.readline())

_seq = [_format(tok) for tok in _seq] # Format seq
        _seq = self.mixup_fn(_seq)  # Data augmentation

src = self.tokenizer.trunc_seq(
seq=_seq,
Expand All @@ -141,7 +143,7 @@ def _format(tok):
seq_len=self.config["max_seq_len"],
)

        return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt)

def _build_index(self):
self.file_mmap.seek(0)
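
With this commit, AmtDataset returns raw waveform segments instead of precomputed mel spectrograms, so AudioTransform can be applied per batch at training time. A hedged sketch of how the pieces might fit together (the DataLoader settings, file path, and train step are assumptions; the actual training-loop changes live in the changed files not shown above):

import torch
from amt.data import AmtDataset
from amt.audio import AudioTransform

dataset = AmtDataset("train.jsonl")  # illustrative path to a file built by write_features
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
transform = AudioTransform()

for wav, src, tgt in loader:
    mel = transform(wav)  # reverb + noise + spec augmentation on raw audio
    # logits = model(mel, src); compute loss against tgt ...
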
(The remaining 7 changed files are not shown.)
