Skip to content

Commit

Permalink
fix model
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad committed Feb 25, 2024
1 parent 7cae45b commit 01845ee
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 57 deletions.
29 changes: 25 additions & 4 deletions amt/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import mmap
import os
import logging
import json
import jsonlines
import torch
Expand All @@ -20,7 +21,27 @@
STRIDE_FACTOR = config["stride_factor"]


def get_features(audio_path: str, mid_path: str | None = None):
def setup_logger():
# Get logger and reset all handlers
logger = logging.getLogger(__name__)
for h in logger.handlers[:]:
logger.removeHandler(h)

logger.propagate = False
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[%(asctime)s] %(name)s: [%(levelname)s] %(message)s",
)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)

return logger


def get_features(audio_path: str, mid_path: str = ""):
"""This function yields tuples of matched log mel spectrograms and
tokenized sequences (np.array, list). If it is given only an audio path
then it will return an empty list for the mid_feature
Expand All @@ -30,14 +51,14 @@ def get_features(audio_path: str, mid_path: str | None = None):
if not os.path.isfile(audio_path):
return None

if mid_path is not None:
if mid_path == "":
pass
elif not os.path.isfile(mid_path):
return None

try:
log_spec = log_mel_spectrogram(audio=audio_path)
if mid_path is not None:
if mid_path != "":
midi_dict = MidiDict.from_midi(mid_path)
else:
midi_dict = None
Expand All @@ -49,7 +70,7 @@ def get_features(audio_path: str, mid_path: str | None = None):
res = []
for start_frame in range(0, total_frames, N_FRAMES // STRIDE_FACTOR):
audio_feature = pad_or_trim(log_spec[:, start_frame:], length=N_FRAMES)
if midi_dict:
if midi_dict is not None:
mid_feature = tokenizer._tokenize_midi_dict(
midi_dict=midi_dict,
start_ms=start_frame * 10,
Expand Down
9 changes: 7 additions & 2 deletions amt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

# TODO: Implement this with KV-caching, see the whisper inference file

# Due to the autoregressive nature, a good inference algorithm should use some
# sort of branching to make sure that we don't miss notes, ect... Implement this
# next week -- Exciting problem (checkout other inference algos)


def greedy_sample(
model: AmtEncoderDecoder,
Expand Down Expand Up @@ -43,8 +47,9 @@ def _process_segment(
)
):
logits = model.forward(mel=audio_seg, tokens=seq[:, :idx])
probs = torch.softmax(logits[0, -1], dim=-1)
next_tok_id = torch.multinomial(probs / 0.001, num_samples=1)
next_tok_id = torch.argmax(logits[0, -1], dim=-1)
# probs = torch.softmax(logits[0, -1], dim=-1)
# next_tok_id = torch.argmax(probs, dim=-1)

# Debug logging:
# print(f"input seq shape: {seq[:, :idx].shape}")
Expand Down
68 changes: 23 additions & 45 deletions amt/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Contains code modified from https://github.com/openai/whisper"""

import math
import numpy as np
import torch
import torch.nn.functional as F
Expand All @@ -8,9 +9,6 @@
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

# TODO:
# Go through and make this more efficient using flash attention ect...


@dataclass
class ModelConfig:
Expand All @@ -29,40 +27,20 @@ def set_vocab_size(self, vocab_size: int):
self.n_vocab = vocab_size


class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)


class Conv1d(nn.Conv1d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)


def sinusoids(length, channels, max_timescale=10000):
def sinusoids(
length: int, channels: int, max_timescale: float = 10000
) -> torch.Tensor:
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
if channels % 2 != 0:
raise ValueError(
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
)
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(
-log_timescale_increment * torch.arange(channels // 2)
)
scaled_time = (
torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
)
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)


class MultiHeadAttention(nn.Module):
Expand All @@ -72,10 +50,10 @@ def __init__(self, n_state: int, n_head: int):

self.n_head = n_head
self.d_head = n_state // n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)

def forward(
self,
Expand Down Expand Up @@ -170,18 +148,18 @@ def __init__(
super().__init__()

self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.attn_ln = nn.LayerNorm(n_state)

self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None

n_mlp = n_state * 4
self.mlp = nn.Sequential(
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
)
self.mlp_ln = LayerNorm(n_state)
self.mlp_ln = nn.LayerNorm(n_state)

def forward(
self,
Expand All @@ -207,16 +185,16 @@ def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(
n_state, n_state, kernel_size=3, stride=2, padding=1
)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
self.ln_post = nn.LayerNorm(n_state)

def forward(self, x: Tensor):
"""
Expand Down Expand Up @@ -253,7 +231,7 @@ def __init__(
for _ in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
self.ln = nn.LayerNorm(n_state)

mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
Expand Down
5 changes: 3 additions & 2 deletions amt/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def get_pretrain_optim(
):
LR = 3e-4
END_RATIO = 0.1
WARMUP_STEPS = 200
WARMUP_STEPS = 500

return _get_optim(
lr=LR,
Expand Down Expand Up @@ -210,6 +210,7 @@ def get_dataloaders(
num_workers: int,
):
logger = get_logger(__name__)
logger.info("Indexing datasets...")
train_dataset = AmtDataset(load_path=train_data_path)
val_dataset = AmtDataset(load_path=val_data_path)
logger.info(
Expand All @@ -220,7 +221,7 @@ def get_dataloaders(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True, # Maybe remove
shuffle=True,
)
val_dataloader = DataLoader(
val_dataset,
Expand Down
2 changes: 1 addition & 1 deletion config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"chunk_len": 30
},
"data": {
"stride_factor": 1,
"stride_factor": 3,
"max_seq_len": 4096
}
}
6 changes: 3 additions & 3 deletions tests/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int):

DELTA_MS = 5000
tokenizer = AmtTokenizer()
midi_dict = MidiDict.from_midi("tests/test_data/bach.mid")
midi_dict = MidiDict.from_midi("tests/test_data/maestro2.mid")
__end_ms = midi_dict.note_msgs[-1]["data"]["end"]

for idx, __start_ms in enumerate(range(0, __end_ms, DELTA_MS)):
Expand Down Expand Up @@ -86,13 +86,13 @@ def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int):
tokenized_seq, DELTA_MS
)
_mid = _midi_dict.to_midi()
_mid.save(f"tests/test_results/bach_orig.mid")
_mid.save(f"tests/test_results/maestro2_orig.mid")

_midi_dict = tokenizer._detokenize_midi_dict(
aug_tokenized_seq, DELTA_MS
)
_mid = _midi_dict.to_midi()
_mid.save(f"tests/test_results/bach_aug.mid")
_mid.save(f"tests/test_results/maestro2_aug.mid")


if __name__ == "__main__":
Expand Down

0 comments on commit 01845ee

Please sign in to comment.