Skip to content

Commit

Permalink
Small adjustment to inference script (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad authored Jul 5, 2024
1 parent e67ac28 commit effaa99
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 8 deletions.
33 changes: 27 additions & 6 deletions amt/inference/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,26 @@ def decode_token(
return logits, next_tok_ids


@torch.no_grad()
def prefill(
    model: AmtEncoderDecoder,
    x: torch.Tensor,
    xa: torch.Tensor,
    x_input_pos: torch.Tensor,
    xa_input_pos: torch.Tensor,
):
    """Run one uncompiled decoder forward pass and pick the argmax token.

    Performs the same computation as ``decode_token``; it is kept as a
    separate function so that the prefill step is not compiled.

    Args:
        model: Model whose ``decoder.forward`` is invoked.
        x: Forwarded to the decoder as ``x``.
        xa: Forwarded to the decoder as ``xa``.
        x_input_pos: Forwarded to the decoder as ``x_input_pos``.
        xa_input_pos: Forwarded to the decoder as ``xa_input_pos``.

    Returns:
        A ``(logits, next_tok_ids)`` tuple: the decoder output sliced at the
        last index of dimension 1, and the argmax of that slice over its
        final dimension.
    """
    decoder_out = model.decoder.forward(
        x=x,
        xa=xa,
        x_input_pos=x_input_pos,
        xa_input_pos=xa_input_pos,
    )
    # Only the final position along dim 1 is needed to choose the next token.
    last_step_logits = decoder_out[:, -1]

    return last_step_logits, last_step_logits.argmax(dim=-1)


@optional_bf16_autocast
@torch.no_grad()
def process_segments(
Expand Down Expand Up @@ -196,9 +216,7 @@ def process_segments(
)
):
# for idx in range(min_prefix_len, MAX_BLOCK_LEN - 1):
with torch.backends.cuda.sdp_kernel(
enable_flash=False, enable_mem_efficient=False, enable_math=True
):
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
if idx == min_prefix_len:
logits, next_tok_ids = decode_token(
model,
Expand Down Expand Up @@ -442,7 +460,7 @@ def _truncate_seq(
return res


# This is a sloppy implementation
# TODO: Add detection for pedal messages which occur before notes are played
def process_silent_intervals(
seq: list, intervals: list, tokenizer: AmtTokenizer
):
Expand Down Expand Up @@ -621,7 +639,7 @@ def transcribe_file(
seq, intervals=silent_intervals, tokenizer=tokenizer
)

if len(seq_adj) < len(seq) - 3:
if len(seq_adj) < len(seq) - 5:
logger.info(
f"Removed tokens ({len(seq)} -> {len(seq_adj)}) "
f"in segment {idx} according to silence in intervals: "
Expand Down Expand Up @@ -718,7 +736,8 @@ def _save_seq(_seq: list, _save_path: str):

try:
mid_dict = tokenizer._detokenize_midi_dict(
tokenized_seq=_seq, len_ms=last_onset
tokenized_seq=_seq,
len_ms=last_onset,
)
mid = mid_dict.to_midi()
mid.save(_save_path)
Expand Down Expand Up @@ -779,6 +798,8 @@ def watchdog(main_gpu_pid: int, child_pids: list):
except ProcessLookupError:
pass

return

time.sleep(1)


Expand Down
9 changes: 8 additions & 1 deletion amt/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@

from csv import DictReader

# V2 ideas:
# - When we retrain, perhaps provide many examples of non-piano audio, each matched
#   with a special token which denotes a non-piano audio segment.
# - We could also occasionally splice a piano segment with non-piano audio and task
#   the model with detecting this using the tokenizer.
# - Retrain with a much larger synth dataset.


# TODO: Implement a way of inferring the tokenizer name automatically
def _add_maestro_args(subparser):
Expand Down Expand Up @@ -412,7 +419,7 @@ def transcribe(
os.path.join(load_dir, "**/*.wav"), recursive=True
)
found_mp3 = glob.glob(
os.path.join(load_dir, "**/*.mp3"), recursive=True
os.path.join(load_dir, "**/*.mp[34]"), recursive=True
)
print(f"Found {len(found_mp3)} mp3 and {len(found_wav)} wav files")
file_paths = found_mp3 + found_wav
Expand Down
4 changes: 3 additions & 1 deletion scripts/eval/dtw.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def calc_score(_midi_cqt, _audio_cqt):
def get_matched_files(audio_dir: str, mid_dir: str):
# We assume that the files have the same path relative to their directory
res = []
wav_paths = glob.glob(os.path.join(audio_dir, "**/*.mp3"), recursive=True)
wav_paths = glob.glob(
os.path.join(audio_dir, "**/*.mp[34]"), recursive=True
)
print(f"found {len(wav_paths)} mp3 files")

for wav_path in wav_paths:
Expand Down

0 comments on commit effaa99

Please sign in to comment.