Commit
Implement training, inference, and CLI (#6)
* skeleton

* convert train script from aria - not tested

* msg

* add cli

* fix

* fix

* fix log

* fix dataset padding

* add bos tok

* remove mp test

* fix modelconfig

* add infer

* add sample cli

* fix

* implement training and inference

* fix format

---------

Co-authored-by: Louis Bradshaw <[email protected]>
loubbrad and Louis Bradshaw authored Feb 20, 2024
1 parent 4a9a70d commit c814a66
Showing 14 changed files with 1,365 additions and 55 deletions.
89 changes: 62 additions & 27 deletions amt/data.py
@@ -20,18 +20,24 @@
STRIDE_FACTOR = config["stride_factor"]


def get_features(audio_path: str, mid_path: str):
def get_features(audio_path: str, mid_path: str | None = None):
"""This function yields tuples of matched log mel spectrograms and
tokenized sequences (np.array, list).
tokenized sequences (np.array, list). If only an audio path is given, an
empty list is returned for the mid_feature.
"""
tokenizer = AmtTokenizer()

if not os.path.isfile(audio_path) or not os.path.isfile(mid_path):
if not os.path.isfile(audio_path):
return None
if (mid_path is not None) and (not os.path.isfile(mid_path)):
return None

try:
midi_dict = MidiDict.from_midi(mid_path)
log_spec = log_mel_spectrogram(audio=audio_path)
if mid_path is not None:
midi_dict = MidiDict.from_midi(mid_path)
else:
midi_dict = None
except Exception as e:
print("Failed to convert files into features")
return None
@@ -40,19 +46,27 @@ def get_features(audio_path: str, mid_path: str):
res = []
for start_frame in range(0, total_frames, N_FRAMES // STRIDE_FACTOR):
audio_feature = pad_or_trim(log_spec[:, start_frame:], length=N_FRAMES)
mid_feature = tokenizer._tokenize_midi_dict(
midi_dict=midi_dict,
start_ms=start_frame * 10,
end_ms=(start_frame + N_FRAMES) * 10,
)
if midi_dict:
mid_feature = tokenizer._tokenize_midi_dict(
midi_dict=midi_dict,
start_ms=start_frame * 10,
end_ms=(start_frame + N_FRAMES) * 10,
)
else:
mid_feature = []

res.append((audio_feature, mid_feature))

return res
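# Illustrative usage of the optional mid_path (hypothetical paths, not from this file):
#   get_features("song.wav", "song.mid")  ->  [(log_mel_segment, tokenized_seq), ...]
#   get_features("song.wav")              ->  [(log_mel_segment, []), ...]
# Audio without a matched MIDI file yields an empty mid_feature for every window.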


def get_features_mp(args):
"""Multiprocessing wrapper for get_features"""
res = get_features(*args)
try:
res = get_features(*args)
except Exception as e:
res = None

if res is None:
return False, None
else:
@@ -62,6 +76,7 @@ def get_features_mp(args):
class AmtDataset(torch.utils.data.Dataset):
def __init__(self, load_path: str):
self.tokenizer = AmtTokenizer(return_tensors=True)
self.config = load_config()["data"]
self.aug_fn = self.tokenizer.export_msg_mixup()
self.file_buff = open(load_path, mode="r")
self.file_mmap = mmap.mmap(
@@ -91,14 +106,22 @@ def _format(tok):
self.file_mmap.seek(self.index[idx])

# This isn't going to load properly
spec, seq = json.loads(self.file_mmap.readline()) # Load data from line
spec, _seq = json.loads(
self.file_mmap.readline()
) # Load data from line

spec = torch.tensor(spec) # Format spectrogram into tensor
seq = [_format(tok) for tok in seq] # Format seq
seq = self.aug_fn(seq) # Data augmentation
_seq = [_format(tok) for tok in _seq] # Format seq
_seq = self.aug_fn(_seq) # Data augmentation

src = seq
tgt = seq[1:] + [self.tokenizer.pad_tok]
src = self.tokenizer.trunc_seq(
seq=_seq,
seq_len=self.config["max_seq_len"],
)
tgt = self.tokenizer.trunc_seq(
seq=_seq[1:],
seq_len=self.config["max_seq_len"],
)

return spec, self.tokenizer.encode(src), self.tokenizer.encode(tgt)
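# Illustration of the src/tgt shift (hypothetical 4-token sequence, assuming
# trunc_seq pads or truncates to max_seq_len with pad_tok):
#   _seq = [bos, A, B, eos]
#   src  = [bos, A, B, eos]
#   tgt  = [A,  B,  eos, pad]   # next-token targets aligned with src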

@@ -120,20 +143,32 @@ def build(
cls,
matched_load_paths: list[tuple[str, str]],
save_path: str,
audio_aug_hook: Callable | None = None,
num_processes: int = 4,
):
def _get_features(_matched_load_paths: list):
with Pool(4) as pool:
results = pool.imap(get_features_mp, _matched_load_paths)
num_paths = len(_matched_load_paths)
for idx, (success, res) in enumerate(results):
if idx % 50 == 0 and idx != 0:
print(f"Processed audio-mid pairs: {idx}/{num_paths}")

if success == False:
continue
for _audio_feature, _mid_feature in res:
yield _audio_feature.tolist(), _mid_feature
num_paths = len(_matched_load_paths)
for idx, entry in enumerate(_matched_load_paths):
success, res = get_features_mp(entry)
if idx % 10 == 0 and idx != 0:
print(f"Processed audio-mid pairs: {idx}/{num_paths}")
if success == False:
continue
for _audio_feature, _mid_feature in res:
yield _audio_feature.tolist(), _mid_feature

# MP CODE DOESN'T WORK FOR SOME REASON !!

# with Pool(num_processes) as pool:
# results = pool.imap(get_features_mp, _matched_load_paths)
# num_paths = len(_matched_load_paths)
# for idx, (success, res) in enumerate(results):
# if idx % 10 == 0 and idx != 0:
# print(f"Processed audio-mid pairs: {idx}/{num_paths}")

# if success == False:
# continue
# for _audio_feature, _mid_feature in res:
# yield _audio_feature.tolist(), _mid_feature

with jsonlines.open(save_path, mode="w") as writer:
for audio_feature, mid_feature in _get_features(matched_load_paths):
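As a usage sketch of the new dataset entry points (not taken from the diff; the .wav/.mid paths and the train.jsonl filename are placeholders), building a dataset file and reading an example back might look like this:

from amt.data import AmtDataset

matched = [("data/song.wav", "data/song.mid")]  # hypothetical matched audio/MIDI paths
AmtDataset.build(matched_load_paths=matched, save_path="train.jsonl")

dataset = AmtDataset(load_path="train.jsonl")
spec, src, tgt = dataset[0]  # log-mel segment, input token ids, shifted target ids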
112 changes: 112 additions & 0 deletions amt/inference.py
@@ -0,0 +1,112 @@
import os
import random
import torch

from tqdm import tqdm

from amt.model import AmtEncoderDecoder
from amt.tokenizer import AmtTokenizer
from amt.data import get_features
from amt.config import load_config
from aria.data.midi import MidiDict


# TODO: Implement this with KV-caching, see the whisper inference file


def greedy_sample(
model: AmtEncoderDecoder,
audio_path: str,
device: str,
):
LEN_MS = 30000 # This should not be hardcoded
MAX_SEQ_LEN = model.dims.n_text_ctx

def _process_segment(
audio_seg: torch.tensor,
prefix: list,
model: AmtEncoderDecoder,
tokenizer: AmtTokenizer = AmtTokenizer(),
):
start_idx = len(prefix)
pad_id = tokenizer.pad_id
eos_id = tokenizer.tok_to_id[tokenizer.eos_tok]
audio_seg = audio_seg.unsqueeze(0).to(device)
seq = tokenizer.encode(tokenizer.trunc_seq(prefix, MAX_SEQ_LEN))
seq = torch.tensor(seq).unsqueeze(0).to(device)

for idx in (
pbar := tqdm(
range(start_idx, MAX_SEQ_LEN - 1),
total=MAX_SEQ_LEN - (start_idx + 1),
leave=False,
)
):
logits = model.forward(mel=audio_seg, tokens=seq[:, :idx])
probs = torch.softmax(logits[0, -1], dim=-1)
next_tok_id = torch.multinomial(probs / 0.001, num_samples=1)

# Debug logging:
# print(f"input seq shape: {seq[:, :idx].shape}")
# print(f"logits shape: {logits.shape}")
# print(f"probs shape: {probs.shape}")
# print(int(next_tok_id), tokenizer.id_to_tok[int(next_tok_id)])

if next_tok_id == pad_id or next_tok_id == eos_id:
break
else:
seq[0, idx] = next_tok_id

if idx == MAX_SEQ_LEN - 2:
print("WARNING: Ran out of context when generating sequence")

seq = tokenizer.decode(seq[0, :])
_, unclosed_notes = tokenizer._detokenize_midi_dict(
tokenized_seq=seq,
len_ms=LEN_MS,
return_unclosed_notes=True,
)

return seq, unclosed_notes

audio_segments = [f for f, _ in get_features(audio_path=audio_path)]
print(f"{len(audio_segments)} audio segments to process...")

model.to(device)
model.eval()
tokenizer = AmtTokenizer()
_unclosed_notes = []
concat_seq = []
_onset_adj = 0
for idx, _audio_seg in enumerate(audio_segments):
_seq = [("prev", p) for p in _unclosed_notes] + [tokenizer.bos_tok]

_seq, _unclosed_notes = _process_segment(
audio_seg=_audio_seg,
prefix=_seq,
model=model,
tokenizer=tokenizer,
)
random.shuffle(_unclosed_notes)

# DEBUG
__midi_dict = tokenizer._detokenize_midi_dict(_seq, 30000)
__midi = __midi_dict.to_midi()
__midi.save(f"/weka/proj-aria/aria-amt/samples/res{idx}.mid")

print(f"Done {idx}/{len(audio_segments)}:\n{_seq}")

for tok in _seq:
if type(tok) is tuple and tok[0] == "onset":
_onset_orig = tok[1]
_onset_adj = _onset_orig + (idx * LEN_MS)
concat_seq.append(("onset", _onset_adj))
elif tok is tokenizer.pad_tok:
break
else:
concat_seq.append(tok)

return tokenizer._detokenize_midi_dict(
tokenized_seq=concat_seq,
len_ms=_onset_adj,
)
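For orientation, a sketch of driving greedy_sample directly (the CLI added in this commit is the intended entry point; the model dimensions, checkpoint path, and audio path below are illustrative assumptions):

import torch
from amt.model import AmtEncoderDecoder, ModelConfig
from amt.tokenizer import AmtTokenizer
from amt.inference import greedy_sample

config = ModelConfig(
    n_mels=128, n_audio_ctx=1500, n_audio_state=512, n_audio_head=8, n_audio_layer=6,
    n_text_ctx=4096, n_text_state=512, n_text_head=8, n_text_layer=6,
)  # illustrative values only
config.set_vocab_size(len(AmtTokenizer().tok_to_id))  # assumes tok_to_id covers the vocab
model = AmtEncoderDecoder(config)
model.load_state_dict(torch.load("model.pt"))  # hypothetical checkpoint path
midi_dict = greedy_sample(model=model, audio_path="piano.wav", device="cuda")
midi_dict.to_midi().save("piano.mid")

One sampling detail worth noting: torch.multinomial renormalizes its weights, so probs / 0.001 samples from exactly the same distribution as probs. To push sampling toward greedy, a temperature is normally applied to the logits before the softmax (e.g. torch.softmax(logits[0, -1] / temperature, dim=-1)), or torch.argmax is used directly.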
37 changes: 28 additions & 9 deletions amt/model.py
@@ -8,19 +8,25 @@
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

# TODO:
# Go through and make this more efficient using flash attention etc.


@dataclass
class ModelDimensions:
class ModelConfig:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_vocab: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
n_vocab: Optional[int] = None

def set_vocab_size(self, vocab_size: int):
self.n_vocab = vocab_size


class LayerNorm(nn.LayerNorm):
@@ -87,7 +93,21 @@ def forward(
k = kv_cache[self.key]
v = kv_cache[self.value]

# Use flash attention here !!
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# debug = True
# if debug is True:
# print(f"q shape: {q.shape}")
# print(f"k shape: {k.shape}")
# print(f"v shape: {v.shape}")
# print(f"mask shape: {mask.shape}")

wv, qk = self.qkv_attention(q, k, v, mask)

# if debug is True:
# print(f"att_out shape: {wv.shape}")
# print(f"att_weights shape: {qk.shape}")

return self.out(wv), qk

def qkv_attention(
@@ -174,7 +194,7 @@ def forward(self, x: Tensor):

assert (
x.shape[1:] == self.positional_embedding.shape
), "incorrect audio shape"
), f"incorrect audio shape: {x.shape[1:]} != {self.positional_embedding.shape}"
x = (x + self.positional_embedding).to(x.dtype)

for block in self.blocks:
@@ -229,8 +249,8 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
return logits


class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
class AmtEncoderDecoder(nn.Module):
def __init__(self, dims: ModelConfig):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
@@ -274,10 +294,9 @@ def embed_audio(self, mel: torch.Tensor):
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)

def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
_buff = self.encoder(mel)
return self.decoder(tokens, _buff)

@property
def device(self):
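On the flash-attention TODO in this file: torch.nn.functional.scaled_dot_product_attention can replace the manual matmul/softmax in qkv_attention and dispatches to a fused (FlashAttention-style) kernel when one is available. Below is a sketch of the reshaping involved, assuming the Whisper-style (batch, ctx, state) layout used here; note that the fused path does not return the qk attention weights, so the commented-out debug prints above would need the slow path.

import torch.nn.functional as F

def qkv_attention_sdpa(q, k, v, n_head: int, causal: bool = False):
    # (batch, ctx, state) -> (batch, head, ctx, head_dim)
    q = q.view(*q.shape[:2], n_head, -1).permute(0, 2, 1, 3)
    k = k.view(*k.shape[:2], n_head, -1).permute(0, 2, 1, 3)
    v = v.view(*v.shape[:2], n_head, -1).permute(0, 2, 1, 3)
    # Scaling by 1/sqrt(head_dim) and the optional causal mask are handled internally
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    # Back to (batch, ctx, state); attention weights are not materialized
    return out.permute(0, 2, 1, 3).flatten(start_dim=2)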