diff --git a/amt/audio.py b/amt/audio.py
index d03b1a3..7648c13 100644
--- a/amt/audio.py
+++ b/amt/audio.py
@@ -1,15 +1,19 @@
 """Contains code taken from https://github.com/openai/whisper"""
 import os
+import random
 import torch
-import numpy as np
+import torchaudio
 import torch.nn.functional as F
+import torchaudio.functional as AF
+import numpy as np
 
 from functools import lru_cache
 from subprocess import CalledProcessError, run
 from typing import Optional, Union
 
 from amt.config import load_config
+from amt.tokenizer import AmtTokenizer
 
 # hard-coded audio hyperparameters
 config = load_config()["audio"]
@@ -176,3 +180,192 @@ def log_mel_spectrogram(
     log_spec = (log_spec + 4.0) / 4.0
 
     return log_spec
+
+
+class AudioTransform(torch.nn.Module):
+    def __init__(
+        self,
+        reverb_factor: int = 1,
+        min_snr: int = 10,
+        max_snr: int = 40,
+    ):
+        super().__init__()
+        self.tokenizer = AmtTokenizer()
+        self.reverb_factor = reverb_factor
+        self.min_snr = min_snr
+        self.max_snr = max_snr
+
+        self.config = load_config()["audio"]
+        self.sample_rate = self.config["sample_rate"]
+        self.chunk_len = self.config["chunk_len"]
+        self.num_samples = self.sample_rate * self.chunk_len
+
+        # Audio aug
+        impulse_paths = self._get_paths(
+            os.path.join(os.path.dirname(__file__), "assets", "impulse")
+        )
+        noise_paths = self._get_paths(
+            os.path.join(os.path.dirname(__file__), "assets", "noise")
+        )
+
+        # Register impulses and noises as buffers
+        self.num_impulse = 0
+        for i, impulse in enumerate(self._get_impulses(impulse_paths)):
+            self.register_buffer(f"impulse_{i}", impulse)
+            self.num_impulse += 1
+
+        self.num_noise = 0
+        for i, noise in enumerate(self._get_noise(noise_paths)):
+            self.register_buffer(f"noise_{i}", noise)
+            self.num_noise += 1
+
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=self.config["n_fft"],
+            hop_length=self.config["hop_len"],
+        )
+        self.mel_transform = torchaudio.transforms.MelScale(
+            n_mels=self.config["n_mels"],
+            sample_rate=self.config["sample_rate"],
+            n_stft=self.config["n_fft"] // 2 + 1,
+        )
+        self.spec_aug = torch.nn.Sequential(
+            torchaudio.transforms.FrequencyMasking(
+                freq_mask_param=15, iid_masks=True
+            ),
+            torchaudio.transforms.TimeMasking(
+                time_mask_param=500, iid_masks=True
+            ),
+        )
+
+    def _get_paths(self, dir_path):
+        return [
+            os.path.join(dir_path, f)
+            for f in os.listdir(dir_path)
+            if os.path.isfile(os.path.join(dir_path, f))
+        ]
+
+    def _get_impulses(self, impulse_paths: list):
+        impulses = [torchaudio.load(path) for path in impulse_paths]
+        impulses = [
+            AF.resample(
+                waveform=wav, orig_freq=sr, new_freq=config["sample_rate"]
+            ).mean(0, keepdim=True)[:, : 5 * self.sample_rate]
+            for wav, sr in impulses
+        ]
+        return [
+            wav / torch.linalg.vector_norm(wav, ord=2) for wav in impulses
+        ]
+
+    def _get_noise(self, noise_paths: list):
+        noises = [torchaudio.load(path) for path in noise_paths]
+        noises = [
+            AF.resample(
+                waveform=wav, orig_freq=sr, new_freq=config["sample_rate"]
+            ).mean(0, keepdim=True)[:, : self.num_samples]
+            for wav, sr in noises
+        ]
+
+        for wav in noises:
+            assert wav.shape[-1] == self.num_samples, "noise wav too short"
+
+        return noises
+
+    def apply_reverb(self, wav: torch.Tensor):
+        # wav: (bz, L)
+        batch_size, _ = wav.shape
+
+        reverb_strength = (
+            torch.Tensor([random.uniform(0, 1) for _ in range(batch_size)])
+            .unsqueeze(-1)
+            .to(wav.device)
+        )
+        reverb_type = random.randint(0, self.num_impulse - 1)
+        impulse = getattr(self, f"impulse_{reverb_type}")
+
+        reverb = AF.fftconvolve(wav, impulse, mode="full")[
+            :, : self.num_samples
+        ]
+        if self.reverb_factor > 1:
+            for _ in range(self.reverb_factor - 1):
+                reverb = AF.fftconvolve(reverb, impulse, mode="full")[
+                    :, : self.num_samples
+                ]
+
+        res = (reverb_strength * reverb) + ((1 - reverb_strength) * wav)
+
+        return res
+
+    def apply_noise(self, wav: torch.Tensor):
+        batch_size, _ = wav.shape
+
+        snr_dbs = torch.tensor(
+            [
+                random.randint(self.min_snr, self.max_snr)
+                for _ in range(batch_size)
+            ]
+        ).to(wav.device)
+        noise_type = random.randint(0, self.num_noise - 1)
+        noise = getattr(self, f"noise_{noise_type}")
+
+        return AF.add_noise(waveform=wav, noise=noise, snr=snr_dbs)
+
+    def shift_spec(self, specs: torch.Tensor, shift: int):
+        if shift == 0:
+            return specs
+
+        freq_mult = 2 ** (shift / 12.0)
+        _, num_bins, L = specs.shape
+        new_num_bins = int(num_bins * freq_mult)
+
+        # Interpolate expects an extra channel dim
+        specs = specs.unsqueeze(1)
+        shifted_specs = torch.nn.functional.interpolate(
+            specs, size=(new_num_bins, L), mode="bilinear", align_corners=False
+        )
+        shifted_specs = shifted_specs.squeeze(1)
+
+        if shift > 0:
+            shifted_specs = shifted_specs[:, :num_bins, :]
+        else:
+            padding = num_bins - shifted_specs.size(1)
+            shifted_specs = torch.nn.functional.pad(
+                shifted_specs, (0, 0, 0, padding), "constant", 0
+            )
+
+        return shifted_specs
+
+    def aug_wav(self, wav: torch.Tensor):
+        return self.apply_reverb(self.apply_noise(wav))
+
+    def norm_mel(self, mel_spec: torch.Tensor):
+        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+        max_over_mels = log_spec.max(dim=1, keepdim=True)[0]
+        max_log_spec = max_over_mels.max(dim=2, keepdim=True)[0]
+        log_spec = torch.maximum(log_spec, max_log_spec - 8.0)
+        log_spec = (log_spec + 4.0) / 4.0
+
+        return log_spec
+
+    def log_mel(self, wav: torch.Tensor, shift: int | None = None):
+        spec = self.spec_transform(wav)[..., :-1]
+        if shift is not None and shift != 0:
+            spec = self.shift_spec(spec, shift)
+        mel_spec = self.mel_transform(spec)
+
+        # Norm
+        log_spec = self.norm_mel(mel_spec)
+
+        return log_spec
+
+    def forward(self, wav: torch.Tensor, shift: int = 0):
+        # Reverb & noise
+        wav = self.aug_wav(wav)
+
+        # Spec & pitch shift
+        log_mel = self.log_mel(wav, shift)
+
+        # Spec aug
+        if random.random() > 0.2:
+            log_mel = self.spec_aug(log_mel)
+
+        return log_mel
diff --git a/amt/data.py b/amt/data.py
index f10bbaa..777d1f5 100644
--- a/amt/data.py
+++ b/amt/data.py
@@ -3,23 +3,17 @@
 import shutil
 import orjson
 import torch
+import torchaudio
 
 from multiprocessing import Pool
 from aria.data.midi import MidiDict
 from amt.tokenizer import AmtTokenizer
 from amt.config import load_config
-from amt.audio import (
-    log_mel_spectrogram,
-    pad_or_trim,
-    N_FRAMES,
-)
+from amt.audio import pad_or_trim
 
-config = load_config()
-STRIDE_FACTOR = config["data"]["stride_factor"]
-
-def get_features(
+
+def get_wav_mid_segments(
     audio_path: str, mid_path: str = "", return_json: bool = False
 ):
     """This function yields tuples of matched log mel spectrograms and
@@ -27,35 +21,43 @@
     then it will return an empty list for the mid_feature
     """
     tokenizer = AmtTokenizer()
-    n_mels = config["audio"]["n_mels"]
+    config = load_config()
+    stride_factor = config["data"]["stride_factor"]
+    sample_rate = config["audio"]["sample_rate"]
+    chunk_len = config["audio"]["chunk_len"]
+    num_samples = sample_rate * chunk_len
+    samples_per_ms = sample_rate // 1000
 
     if not os.path.isfile(audio_path):
         return None
 
+    # Load midi if required
     if mid_path == "":
-        pass
+        midi_dict = None
    elif not os.path.isfile(mid_path):
         return None
-
-    try:
-        log_spec = log_mel_spectrogram(audio=audio_path, n_mels=n_mels)
-        if mid_path != "":
-            midi_dict = MidiDict.from_midi(mid_path)
-        else:
-            midi_dict = None
-    except Exception as e:
-        print("Failed to convert files into features")
-        raise e
-
-    _, total_frames = log_spec.shape
+    else:
+        midi_dict = MidiDict.from_midi(mid_path)
+
+    # Load audio (mono, at the configured sample rate)
+    wav, sr = torchaudio.load(audio_path)
+    wav = wav.mean(0)
+    if sr != sample_rate:
+        wav = torchaudio.functional.resample(
+            waveform=wav,
+            orig_freq=sr,
+            new_freq=sample_rate,
+        )
+
+    # Create features
+    total_samples = wav.shape[-1]
     res = []
-    for start_frame in range(0, total_frames, N_FRAMES // STRIDE_FACTOR):
-        audio_feature = pad_or_trim(log_spec[:, start_frame:], length=N_FRAMES)
+    for idx in range(0, total_samples, num_samples // stride_factor):
+        audio_feature = pad_or_trim(wav[idx:], length=num_samples)
         if midi_dict is not None:
             mid_feature = tokenizer._tokenize_midi_dict(
                 midi_dict=midi_dict,
-                start_ms=start_frame * 10,
-                end_ms=(start_frame + N_FRAMES) * 10,
+                start_ms=idx // samples_per_ms,
+                end_ms=(idx + num_samples) // samples_per_ms,
             )
         else:
             mid_feature = []
@@ -70,7 +72,7 @@
 
 def write_features(args):
     audio_path, mid_path, save_path = args
-    features = get_features(
+    features = get_wav_mid_segments(
         audio_path=audio_path,
         mid_path=mid_path,
         return_json=False,
@@ -79,10 +81,10 @@ def write_features(args):
     proc_save_path = os.path.join(dirname, str(os.getpid()) + basename)
 
     with open(proc_save_path, mode="ab") as file:
-        for mel, seq in features:
+        for wav, seq in features:
             file.write(
                 orjson.dumps(
-                    mel.numpy(),
+                    wav.numpy(),
                     option=orjson.OPT_SERIALIZE_NUMPY,
                 )
             )
@@ -97,7 +99,7 @@ class AmtDataset(torch.utils.data.Dataset):
     def __init__(self, load_path: str):
         self.tokenizer = AmtTokenizer(return_tensors=True)
         self.config = load_config()["data"]
-        self.aug_fn = self.tokenizer.export_msg_mixup()
+        self.mixup_fn = self.tokenizer.export_msg_mixup()
         self.file_buff = open(load_path, mode="r")
         self.file_mmap = mmap.mmap(
             self.file_buff.fileno(), 0, access=mmap.ACCESS_READ
@@ -126,11 +128,11 @@ def _format(tok):
         self.file_mmap.seek(self.index[idx])
 
         # Load data from line
-        mel = torch.tensor(orjson.loads(self.file_mmap.readline()))
+        wav = torch.tensor(orjson.loads(self.file_mmap.readline()))
         _seq = orjson.loads(self.file_mmap.readline())
 
         _seq = [_format(tok) for tok in _seq]  # Format seq
-        _seq = self.aug_fn(_seq)  # Data augmentation
+        _seq = self.mixup_fn(_seq)  # Data augmentation
 
         src = self.tokenizer.trunc_seq(
             seq=_seq,
@@ -141,7 +143,7 @@
             seq_len=self.config["max_seq_len"],
         )
 
-        return mel, self.tokenizer.encode(src), self.tokenizer.encode(tgt)
+        return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt)
 
     def _build_index(self):
         self.file_mmap.seek(0)
diff --git a/amt/tokenizer.py b/amt/tokenizer.py
index 3a8d58b..09a22ea 100644
--- a/amt/tokenizer.py
+++ b/amt/tokenizer.py
@@ -1,7 +1,9 @@
 import random
 import os
 import copy
+import functools
 
+from torch import Tensor
 from collections import defaultdict
 from aria.data.midi import MidiDict, get_duration_ms
 
@@ -395,3 +397,38 @@ def round_to_base(n, base=150):
             return res
 
         return msg_mixup
+
+    def export_tensor_pitch_aug(self):
+        def tensor_pitch_aug(
+            seq: Tensor,
+            shift: int,
+            tok_to_id: dict,
+            id_to_tok: dict,
+            pad_tok: str,
+            unk_tok: str,
+        ):
+            """This acts on (batched) tensors, applying pitch aug in place"""
+            if shift == 0:
+                return seq
+
+            batch_size, seq_len = seq.shape
+            for i in range(batch_size):
+                for j in range(seq_len):
+                    tok = id_to_tok[seq[i, j].item()]
+                    if type(tok) is tuple and tok[0] in {"on", "off"}:
+                        msg_type, pitch = tok
+                        seq[i, j] = tok_to_id.get(
+                            (msg_type, pitch + shift), tok_to_id[unk_tok]
+                        )
+                    elif tok == pad_tok:
+                        break
+
+            return seq
+
+        return functools.partial(
+            tensor_pitch_aug,
+            tok_to_id=self.tok_to_id,
+            id_to_tok=self.id_to_tok,
+            pad_tok=self.pad_tok,
+            unk_tok=self.unk_tok,
+        )
diff --git a/amt/train.py b/amt/train.py
index ed473bb..fcb7e39 100644
--- a/amt/train.py
+++ b/amt/train.py
@@ -1,6 +1,8 @@
 import os
 import sys
 import csv
+import random
+import functools
 import argparse
 import logging
 import torch
@@ -18,6 +20,7 @@
 
 from amt.tokenizer import AmtTokenizer
 from amt.model import AmtEncoderDecoder, ModelConfig
+from amt.audio import AudioTransform
 from amt.data import AmtDataset
 from amt.config import load_model_config
 from aria.utils import _load_weight
@@ -217,10 +220,26 @@ def get_dataloaders(
         f"Loaded datasets with length: train={len(train_dataset)}; val={len(val_dataset)}"
     )
 
+    # Pitch aug (on the sequence tensors) must be applied in the train
+    # dataloader so that every element in the batch is shifted equally.
+    # Running it on the main process was causing a bottleneck.
+    tensor_pitch_aug = AmtTokenizer().export_tensor_pitch_aug()
+
+    def _collate_fn(seqs, max_pitch_shift: int):
+        wav, src, tgt = torch.utils.data.default_collate(seqs)
+
+        # Pitch aug
+        pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
+        src = tensor_pitch_aug(seq=src, shift=pitch_shift)
+        tgt = tensor_pitch_aug(seq=tgt, shift=pitch_shift)
+
+        return wav, src, tgt, pitch_shift
+
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
+        collate_fn=functools.partial(_collate_fn, max_pitch_shift=5),
         shuffle=True,
     )
     val_dataloader = DataLoader(
@@ -248,6 +267,7 @@ def _train(
     model: AmtEncoderDecoder,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
+    audio_transform: AudioTransform,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LRScheduler = None,
     steps_per_checkpoint: int | None = None,
@@ -258,7 +278,9 @@ def _train(
 
     def profile_flops(dataloader: DataLoader):
         def _bench():
             for batch in dataloader:
-                mel, src, tgt = batch  # (b_sz, s_len), (b_sz, s_len, v_sz)
+                wav, src, tgt, pitch_shift = batch
+                with torch.no_grad():
+                    mel = audio_transform.forward(wav, shift=pitch_shift)
                 logits = model(mel, src)  # (b_sz, s_len, v_sz)
                 logits = logits.transpose(1, 2)
                 loss = loss_fn(logits, tgt)
@@ -268,19 +290,18 @@ def _bench():
                 optimizer.zero_grad()
                 break
 
-        flop_counter = FlopCounterMode(display=False)
         logger.info(
             f"Model has "
             f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} "
             "parameters"
         )
-        logger.info("Profiling FLOP")
+        logger.info("Compiling model...")
         _bench()
-        with flop_counter:
-            _bench()
-        total_flop = sum(flop_counter.get_flop_counts()["Global"].values())
-        logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF")
+        # with flop_counter:
+        #     _bench()
+        # total_flop = sum(flop_counter.get_flop_counts()["Global"].values())
+        # logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF")
 
     def make_checkpoint(_accelerator, _epoch: int, _step: int):
         checkpoint_dir = os.path.join(
@@ -314,8 +335,6 @@ def train_loop(
         lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])
         model.train()
 
-        of_batch_exists = False
-
         for __step, batch in (
             pbar := tqdm(
                 enumerate(dataloader),
@@ -326,14 +345,9 @@ def train_loop(
         ):
             step = __step + _resume_step + 1
 
-            # Code for forcing overfitting
-            # if (overfit is True) and (of_batch_exists is True):
-            #     pass
-            # else:
-            #     of_batch_exists = True
-            #     mel, src, tgt = batch  # (b_sz, s_len), (b_sz, s_len, v_sz)
-
-            mel, src, tgt = batch  # (b_sz, s_len), (b_sz, s_len, v_sz)
+            wav, src, tgt, pitch_shift = batch
+            with torch.no_grad():
+                mel = audio_transform.forward(wav, shift=pitch_shift)
             logits = model(mel, src)  # (b_sz, s_len, v_sz)
             logits = logits.transpose(1, 2)  # Transpose for CrossEntropyLoss
             loss = loss_fn(logits, tgt)
@@ -399,8 +413,9 @@ def val_loop(dataloader, _epoch: int):
                 leave=False,
             )
         ):
-            mel, src, tgt = batch
+            wav, src, tgt = batch
             with torch.no_grad():
+                mel = audio_transform.log_mel(wav)
                 logits = model(mel, src)
                 logits = logits.transpose(1, 2)  # Transpose for CrossEntropyLoss
                 loss = loss_fn(logits, tgt)
@@ -541,6 +556,8 @@ def resume_train(
     model_config = ModelConfig(**load_model_config(model_name))
     model_config.set_vocab_size(tokenizer.vocab_size)
     model = AmtEncoderDecoder(model_config)
+    model = torch.compile(model)
+    audio_transform = AudioTransform().to(accelerator.device)
     logger.info(f"Loaded model with config: {load_model_config(model_name)}")
 
     train_dataloader, val_dataloader = get_dataloaders(
@@ -600,6 +617,7 @@ def resume_train(
         model=model,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
+        audio_transform=audio_transform,
         optimizer=optimizer,
         scheduler=scheduler,
         steps_per_checkpoint=steps_per_checkpoint,
@@ -655,8 +673,8 @@ def train(
     model_config = ModelConfig(**load_model_config(model_name))
     model_config.set_vocab_size(tokenizer.vocab_size)
     model = AmtEncoderDecoder(model_config)
-    # logger.info("Compiling model...")
-    # model = torch.compile(model)
+    model = torch.compile(model)
+    audio_transform = AudioTransform().to(accelerator.device)
     logger.info(f"Loaded model with config: {load_model_config(model_name)}")
     if mode == "finetune":
         try:
@@ -716,6 +734,7 @@ def train(
         model=model,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
+        audio_transform=audio_transform,
        optimizer=optimizer,
         scheduler=scheduler,
         steps_per_checkpoint=steps_per_checkpoint,
diff --git a/config/config.json b/config/config.json
index be4cfd7..67c407e 100644
--- a/config/config.json
+++ b/config/config.json
@@ -17,7 +17,7 @@
         "n_mels": 256
     },
     "data": {
-        "stride_factor": 1,
+        "stride_factor": 3,
         "max_seq_len": 4096
     }
 }
\ No newline at end of file
diff --git a/requirements-dev.txt b/requirements-dev.txt
index b3ca909..b4eb9a1 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1 +1,2 @@
-black
\ No newline at end of file
+black
+matplotlib
\ No newline at end of file
diff --git a/tests/test_data.py b/tests/test_data.py
index 4ba3d4a..5f54124 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -1,10 +1,13 @@
 import unittest
 import logging
 import os
-import time
+import torch
+import torchaudio
+import matplotlib.pyplot as plt
 
-from amt.data import get_features, AmtDataset
+from amt.data import get_wav_mid_segments, AmtDataset
 from amt.tokenizer import AmtTokenizer
+from amt.audio import AudioTransform, log_mel_spectrogram
 from aria.data.midi import MidiDict
 
@@ -17,8 +20,8 @@
 
 # Need to test this properly, have issues turning mel_spec back into audio
 class TestDataGen(unittest.TestCase):
-    def test_feature_gen(self):
-        for log_spec, seq in get_features(
+    def test_wav_mid_segments(self):
+        for log_spec, seq in get_wav_mid_segments(
             audio_path="tests/test_data/147.wav",
             mid_path="tests/test_data/147.mid",
         ):
@@ -41,8 +44,8 @@ def test_build(self):
 
         dataset = AmtDataset("tests/test_results/dataset.jsonl")
         tokenizer = AmtTokenizer()
-        for idx, (spec, src, tgt) in enumerate(dataset):
-            print(spec.shape, src.shape, tgt.shape)
+        for idx, (wav, src, tgt) in enumerate(dataset):
+            print(wav.shape, src.shape, tgt.shape)
             src_decoded = tokenizer.decode(src)
             tgt_decoded = tokenizer.decode(tgt)
             self.assertListEqual(src_decoded[1:], tgt_decoded[:-1])
@@ -58,7 +61,7 @@
     def test_maestro(self):
         tokenizer = AmtTokenizer()
         dataset = AmtDataset(load_path=MAESTRO_PATH)
-        for idx, (mel, src, tgt) in enumerate(dataset):
+        for idx, (wav, src, tgt) in enumerate(dataset):
             src_dec, tgt_dec = tokenizer.decode(src), tokenizer.decode(tgt)
             if (idx + 1) % 100 == 0:
                 break
@@ -78,5 +81,86 @@ def test_maestro(self):
             self.assertEqual(src_tok, tgt_tok)
 
 
+class TestAug(unittest.TestCase):
+    def plot_spec(self, mel: torch.Tensor, name: str | int):
+        plt.figure(figsize=(10, 4))
+        plt.imshow(mel, aspect="auto", origin="lower", cmap="viridis")
+        plt.colorbar(format="%+2.0f dB")
+        plt.title("(mel)-Spectrogram")
+        plt.tight_layout()
+        plt.savefig(f"tests/test_results/{name}.png")
+        plt.close()
+
+    def test_spec(self):
+        SAMPLE_RATE, CHUNK_LEN = 16000, 30
+        audio_transform = AudioTransform()
+        wav, sr = torchaudio.load("tests/test_data/maestro.wav")
+        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE).mean(
+            0, keepdim=True
+        )[:, : SAMPLE_RATE * CHUNK_LEN]
+
+        griffin_lim = torchaudio.transforms.GriffinLim(
+            n_fft=2048,
+            hop_length=160,
+            power=1,
+            n_iter=64,
+        )
+
+        spec = audio_transform.spec_transform(wav)
+        shift_spec = audio_transform.shift_spec(spec, 1)
+        shift_wav = griffin_lim(shift_spec)
+        torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE)
+        torchaudio.save("tests/test_results/shift.wav", shift_wav, SAMPLE_RATE)
+
+        log_mel = log_mel_spectrogram(wav)
+        self.plot_spec(log_mel.squeeze(0), "orig")
+
+        _mel = audio_transform.mel_transform(spec)
+        _log_mel = audio_transform.norm_mel(_mel)
+        self.plot_spec(_log_mel.squeeze(0), "new")
+
+    def test_pitch_aug(self):
+        tokenizer = AmtTokenizer(return_tensors=True)
+        tensor_pitch_aug_fn = tokenizer.export_tensor_pitch_aug()
+        mid_dict = MidiDict.from_midi("tests/test_data/maestro2.mid")
+        seq = tokenizer._tokenize_midi_dict(mid_dict, 0, 30000)
+        src = tokenizer.encode(tokenizer.trunc_seq(seq, 4096))
+        tgt = tokenizer.encode(tokenizer.trunc_seq(seq[1:], 4096))
+
+        src = torch.stack((src, src, src))
+        tgt = torch.stack((tgt, tgt, tgt))
+        src_aug = tensor_pitch_aug_fn(src.clone(), shift=1)
+        tgt_aug = tensor_pitch_aug_fn(tgt.clone(), shift=1)
+
+        src_aug_dec = tokenizer.decode(src_aug[1])
+        tgt_aug_dec = tokenizer.decode(tgt_aug[2])
+        print(seq[:20])
+        print(src_aug_dec[:20])
+        print(tgt_aug_dec[:20])
+
+        for tok, aug_tok in zip(seq, src_aug_dec):
+            if type(tok) is tuple and aug_tok[0] in {"on", "off"}:
+                self.assertEqual(tok[1] + 1, aug_tok[1])
+
+        for src_tok, tgt_tok in zip(src_aug_dec[1:], tgt_aug_dec):
+            self.assertEqual(src_tok, tgt_tok)
+
+    def test_mels(self):
+        SAMPLE_RATE, CHUNK_LEN = 16000, 30
+        audio_transform = AudioTransform()
+        wav, sr = torchaudio.load("tests/test_data/maestro.wav")
+        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE).mean(
+            0, keepdim=True
+        )[:, : SAMPLE_RATE * CHUNK_LEN]
+        wav_aug = audio_transform.aug_wav(wav)
+        torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE)
+        torchaudio.save("tests/test_results/aug.wav", wav_aug, SAMPLE_RATE)
+
+        wavs = torch.stack((wav[0], wav[0], wav[0]))
+        mels = audio_transform(wavs)
+        for idx in range(mels.shape[0]):
+            self.plot_spec(mels[idx], idx)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_data/maestro.wav b/tests/test_data/maestro.wav
new file mode 100644
index 0000000..ad2279b
Binary files /dev/null and b/tests/test_data/maestro.wav differ
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
index 06b24c0..64c1a36 100644
--- a/tests/test_tokenizer.py
+++ b/tests/test_tokenizer.py
@@ -1,9 +1,9 @@
 import unittest
 import logging
+import torch
 import os
 
 from amt.tokenizer import AmtTokenizer
-from aria.tokenizer import AbsTokenizer
 from aria.data.midi import MidiDict
 
 logging.basicConfig(level=logging.INFO)
@@ -11,9 +11,6 @@
     os.mkdir("tests/test_results")
 
 
-# Add test for unk tok
-
-
 class TestAmtTokenizer(unittest.TestCase):
     def test_tokenize(self):
         def _tokenize_detokenize(mid_name: str, start: int, end: int):
@@ -40,6 +37,45 @@ def _tokenize_detokenize(mid_name: str, start: int, end: int):
         _tokenize_detokenize("maestro2.mid", start=START, end=END)
         _tokenize_detokenize("maestro3.mid", start=START, end=END)
 
+    def test_pitch_aug(self):
+        tokenizer = AmtTokenizer(return_tensors=True)
+        pitch_aug_fn = tokenizer.export_tensor_pitch_aug()
+
+        midi_dict_1 = MidiDict.from_midi("tests/test_data/maestro1.mid")
+        midi_dict_2 = MidiDict.from_midi("tests/test_data/maestro2.mid")
+        midi_dict_3 = MidiDict.from_midi("tests/test_data/maestro3.mid")
+        seq_1 = tokenizer.trunc_seq(
+            tokenizer._tokenize_midi_dict(midi_dict_1, 0, 30000), 2048
+        )
+        seq_2 = tokenizer.trunc_seq(
+            tokenizer._tokenize_midi_dict(midi_dict_2, 0, 30000), 2048
+        )
+        seq_3 = tokenizer.trunc_seq(
+            tokenizer._tokenize_midi_dict(midi_dict_3, 0, 30000), 2048
+        )
+
+        seqs = torch.stack(
+            (
+                tokenizer.encode(seq_1),
+                tokenizer.encode(seq_2),
+                tokenizer.encode(seq_3),
+            )
+        )
+        aug_seqs = pitch_aug_fn(seqs, shift=2)
+
+        midi_dict_1_aug = tokenizer._detokenize_midi_dict(
+            tokenizer.decode(aug_seqs[0]), 30000
+        )
+        midi_dict_2_aug = tokenizer._detokenize_midi_dict(
+            tokenizer.decode(aug_seqs[1]), 30000
+        )
+        midi_dict_3_aug = tokenizer._detokenize_midi_dict(
+            tokenizer.decode(aug_seqs[2]), 30000
+        )
+        midi_dict_1_aug.to_midi().save("tests/test_results/pitch1.mid")
+        midi_dict_2_aug.to_midi().save("tests/test_results/pitch2.mid")
+        midi_dict_3_aug.to_midi().save("tests/test_results/pitch3.mid")
+
     def test_aug(self):
         def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int):
             _tokenized_seq = tokenizer._tokenize_midi_dict(