Skip to content

Commit

Permalink
Dataset building and inference improvements (#10)
Browse files Browse the repository at this point in the history
* update spectrogram params

* fix inference and dataset building
  • Loading branch information
loubbrad authored Feb 28, 2024
1 parent 45ebd80 commit 4d12799
Show file tree
Hide file tree
Showing 13 changed files with 119 additions and 104 deletions.
Binary file modified amt/assets/mel_filters.npz
Binary file not shown.
4 changes: 2 additions & 2 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)
"""
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
assert n_mels in {80, 128, 256}, f"Unsupported n_mels: {n_mels}"

filters_path = os.path.join(
os.path.dirname(__file__), "assets", "mel_filters.npz"
Expand All @@ -127,7 +127,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:

def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = 80,
n_mels: int = 256,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
Expand Down
135 changes: 67 additions & 68 deletions amt/data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import mmap
import os
import logging
import json
import jsonlines
import shutil
import orjson
import torch

from typing import Callable
from multiprocessing import Pool

from aria.data.midi import MidiDict
Expand All @@ -17,36 +15,19 @@
N_FRAMES,
)

config = load_config()["data"]
STRIDE_FACTOR = config["stride_factor"]
config = load_config()
STRIDE_FACTOR = config["data"]["stride_factor"]


def setup_logger():
# Get logger and reset all handlers
logger = logging.getLogger(__name__)
for h in logger.handlers[:]:
logger.removeHandler(h)

logger.propagate = False
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[%(asctime)s] %(name)s: [%(levelname)s] %(message)s",
)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)

return logger


def get_features(audio_path: str, mid_path: str = ""):
def get_features(
audio_path: str, mid_path: str = "", return_json: bool = False
):
"""This function yields tuples of matched log mel spectrograms and
tokenized sequences (np.array, list). If it is given only an audio path
then it will return an empty list for the mid_feature
"""
tokenizer = AmtTokenizer()
n_mels = config["audio"]["n_mels"]

if not os.path.isfile(audio_path):
return None
Expand All @@ -57,7 +38,7 @@ def get_features(audio_path: str, mid_path: str = ""):
return None

try:
log_spec = log_mel_spectrogram(audio=audio_path)
log_spec = log_mel_spectrogram(audio=audio_path, n_mels=n_mels)
if mid_path != "":
midi_dict = MidiDict.from_midi(mid_path)
else:
Expand All @@ -79,19 +60,37 @@ def get_features(audio_path: str, mid_path: str = ""):
else:
mid_feature = []

if return_json is True:
audio_feature = audio_feature.tolist()

res.append((audio_feature, mid_feature))

return res


def get_features_mp(args):
"""Multiprocessing wrapper for get_features"""
res = get_features(*args)
def write_features(args):
audio_path, mid_path, save_path = args
features = get_features(
audio_path=audio_path,
mid_path=mid_path,
return_json=False,
)
dirname, basename = os.path.split(save_path)
proc_save_path = os.path.join(dirname, str(os.getpid()) + basename)

with open(proc_save_path, mode="ab") as file:
for mel, seq in features:
file.write(
orjson.dumps(
mel.numpy(),
option=orjson.OPT_SERIALIZE_NUMPY,
)
)
file.write(b"\n")
file.write(orjson.dumps(seq))
file.write(b"\n")

if res is None:
return False, None
else:
return True, res
return proc_save_path


class AmtDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -127,9 +126,9 @@ def _format(tok):
self.file_mmap.seek(self.index[idx])

# Load data from line
spec, _seq = json.loads(self.file_mmap.readline())
mel = torch.tensor(orjson.loads(self.file_mmap.readline()))
_seq = orjson.loads(self.file_mmap.readline())

spec = torch.tensor(spec) # Format spectrogram into tensor
_seq = [_format(tok) for tok in _seq] # Format seq
_seq = self.aug_fn(_seq) # Data augmentation

Expand All @@ -142,15 +141,15 @@ def _format(tok):
seq_len=self.config["max_seq_len"],
)

return spec, self.tokenizer.encode(src), self.tokenizer.encode(tgt)
return mel, self.tokenizer.encode(src), self.tokenizer.encode(tgt)

def _build_index(self):
self.file_mmap.seek(0)
index = []
while True:
pos = self.file_mmap.tell()
line_buffer = self.file_mmap.readline()
if line_buffer == b"":
self.file_mmap.readline()
if self.file_mmap.readline() == b"":
break
else:
index.append(pos)
Expand All @@ -162,33 +161,33 @@ def build(
cls,
matched_load_paths: list[tuple[str, str]],
save_path: str,
num_processes: int = 4,
num_processes: int = 1,
):
def _get_features(_matched_load_paths: list):
num_paths = len(_matched_load_paths)
for idx, entry in enumerate(_matched_load_paths):
success, res = get_features_mp(entry)
assert os.path.isfile(save_path) is False, f"{save_path} already exists"
num_paths = len(matched_load_paths)
with Pool(processes=num_processes) as pool:
sharded_save_paths = []
res = pool.imap_unordered(
write_features,
((ap, mp, save_path) for ap, mp in matched_load_paths),
)
for idx, proc_save_path in enumerate(res):
if idx % 10 == 0 and idx != 0:
print(f"Processed audio-mid pairs: {idx}/{num_paths}")
if success == False:
continue
for _audio_feature, _mid_feature in res:
yield _audio_feature.tolist(), _mid_feature

# MP CODE DOESN'T WORK FOR SOME REASON !!

# with Pool(num_processes) as pool:
# results = pool.imap(get_features_mp, _matched_load_paths)
# num_paths = len(_matched_load_paths)
# for idx, (success, res) in enumerate(results):
# if idx % 10 == 0 and idx != 0:
# print(f"Processed audio-mid pairs: {idx}/{num_paths}")

# if success == False:
# continue
# for _audio_feature, _mid_feature in res:
# yield _audio_feature.tolist(), _mid_feature

with jsonlines.open(save_path, mode="w") as writer:
for audio_feature, mid_feature in _get_features(matched_load_paths):
writer.write([audio_feature, mid_feature])
print(f"Finished {idx}/{num_paths}")
if proc_save_path not in sharded_save_paths:
sharded_save_paths.append(proc_save_path)

# This is bad; however, cat is fast
if shutil.which("cat") is None:
print("The GNU cat command is not available")
else:
print("Concatinating sharded dataset files")
shell_cmd = f"cat "
for _path in sharded_save_paths:
shell_cmd += f"{_path} "
print()
shell_cmd += f">> {save_path}"

os.system(shell_cmd)
for _path in sharded_save_paths:
os.remove(_path)
30 changes: 15 additions & 15 deletions amt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# sort of branching to make sure that we don't miss notes, etc. Implement this
# next week -- Exciting problem (check out other inference algos)

# Implement a maximum note length of 5s
# Implement either beam search or decoding initial onset note on first


def greedy_sample(
model: AmtEncoderDecoder,
Expand All @@ -38,6 +41,7 @@ def _process_segment(
audio_seg = audio_seg.unsqueeze(0).to(device)
seq = tokenizer.encode(tokenizer.trunc_seq(prefix, MAX_SEQ_LEN))
seq = torch.tensor(seq).unsqueeze(0).to(device)
audio_feature = model.embed_audio(mel=audio_seg)

for idx in (
pbar := tqdm(
Expand All @@ -46,21 +50,14 @@ def _process_segment(
leave=False,
)
):
logits = model.forward(mel=audio_seg, tokens=seq[:, :idx])
logits = model.logits(
audio_features=audio_feature, tokens=seq[:, :idx]
)
next_tok_id = torch.argmax(logits[0, -1], dim=-1)
# probs = torch.softmax(logits[0, -1], dim=-1)
# next_tok_id = torch.argmax(probs, dim=-1)

# Debug logging:
# print(f"input seq shape: {seq[:, :idx].shape}")
# print(f"logits shape: {logits.shape}")
# print(f"probs shape: {probs.shape}")
# print(int(next_tok_id), tokenizer.id_to_tok[int(next_tok_id)])

seq[0, idx] = next_tok_id
if next_tok_id == pad_id or next_tok_id == eos_id:
break
else:
seq[0, idx] = next_tok_id

if idx == MAX_SEQ_LEN - 2:
print("WARNING: Ran out of context when generating sequence")
Expand All @@ -81,7 +78,7 @@ def _process_segment(
model.eval()
tokenizer = AmtTokenizer()
_unclosed_notes = []
concat_seq = []
concat_seq = [tokenizer.bos_tok]
_onset_adj = 0
for idx, _audio_seg in enumerate(audio_segments):
_seq = [("prev", p) for p in _unclosed_notes] + [tokenizer.bos_tok]
Expand All @@ -99,14 +96,17 @@ def _process_segment(
__midi = __midi_dict.to_midi()
__midi.save(f"/weka/proj-aria/aria-amt/samples/res{idx}.mid")

print(f"Done {idx}/{len(audio_segments)}:\n{_seq}")

print(f"Done {idx + 1}/{len(audio_segments)}")
for tok in _seq:
if type(tok) is tuple and tok[0] == "onset":
_onset_orig = tok[1]
_onset_adj = _onset_orig + (idx * LEN_MS)
concat_seq.append(("onset", _onset_adj))
elif tok is tokenizer.pad_tok:
elif type(tok) is tuple and tok[0] == "prev":
continue
elif tok is tokenizer.bos_tok:
continue
elif tok is tokenizer.pad_tok or tok is tokenizer.eos_tok:
break
else:
concat_seq.append(tok)
Expand Down
15 changes: 9 additions & 6 deletions amt/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,15 @@ def build_maestro(args):

assert os.path.isdir(args.dir), "MAESTRO directory not found"
assert os.path.isfile(args.csv), "MAESTRO csv not found"
if (
os.path.isfile(args.train)
or os.path.isfile(args.val)
or os.path.isfile(args.test)
):
print("Dataset files already exist - overwriting")
if os.path.isfile(args.train):
print(f"Dataset file already exists at {args.train} - removing")
os.remove(args.train)
if os.path.isfile(args.val):
print(f"Dataset file already exists at {args.val} - removing")
os.remove(args.val)
if os.path.isfile(args.test):
print(f"Dataset file already exists at {args.test} - removing")
os.remove(args.test)

matched_paths_train = []
matched_paths_val = []
Expand Down
4 changes: 3 additions & 1 deletion amt/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ def _detokenize_midi_dict(
if DEBUG:
raise Exception
else:

notes_to_close[tok_1_data] = (tok_2_data, tok_3_data)
elif tok_1_type == "off":
if tok_2_type != "onset":
Expand Down Expand Up @@ -322,6 +321,9 @@ def export_data_aug(self):

def export_msg_mixup(self):
def msg_mixup(src: list):
def round_to_base(n, base=150):
return base * round(n / base)

# Process bos, eos, and pad tokens
orig_len = len(src)
seen_pad_tok = False
Expand Down
9 changes: 6 additions & 3 deletions amt/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def _get_optim(
optimizer = torch.optim.AdamW(
model.parameters(),
lr=lr,
weight_decay=0.01,
betas=(0.9, 0.95),
eps=1e-5,
weight_decay=0.1,
betas=(0.9, 0.98),
eps=1e-6,
)

warmup_lrs = torch.optim.lr_scheduler.LinearLR(
Expand Down Expand Up @@ -365,6 +365,9 @@ def train_loop(

# Backwards step
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), 1.0)

optimizer.step()
optimizer.zero_grad()
if scheduler:
Expand Down
7 changes: 4 additions & 3 deletions config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
},
"audio": {
"sample_rate": 16000,
"n_fft": 400,
"n_fft": 2048,
"hop_len": 160,
"chunk_len": 30
"chunk_len": 30,
"n_mels": 256
},
"data": {
"stride_factor": 3,
"stride_factor": 1,
"max_seq_len": 4096
}
}
2 changes: 1 addition & 1 deletion config/models/medium.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"n_mels": 80,
"n_mels": 256,
"n_audio_ctx": 1500,
"n_audio_state": 512,
"n_audio_head": 8,
Expand Down
2 changes: 1 addition & 1 deletion config/models/small.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"n_mels": 80,
"n_mels": 256,
"n_audio_ctx": 1500,
"n_audio_state": 384,
"n_audio_head": 6,
Expand Down
2 changes: 1 addition & 1 deletion config/models/test.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"n_mels": 80,
"n_mels": 256,
"n_audio_ctx": 1500,
"n_audio_state": 64,
"n_audio_head": 4,
Expand Down
Loading

0 comments on commit 4d12799

Please sign in to comment.