
Adjust scripts, readme, and config (#49)
* fix to inference

* change lr

* reorganize readme scripts and config
loubbrad authored Aug 12, 2024
1 parent d77292a commit 352de4a
Showing 16 changed files with 88 additions and 113 deletions.
7 changes: 3 additions & 4 deletions README.md
@@ -17,7 +17,7 @@ Download the preliminary model weights:
 Piano (v1)

 ```
-wget https://storage.googleapis.com/aria-checkpoints/amt/piano-medium-stacked-1.0.safetensors
+wget https://storage.googleapis.com/aria-checkpoints/amt/piano-medium-double-1.0.safetensors
 ```

 ## Usage
@@ -32,13 +32,12 @@ You can then transcribe using the cli:

 ```
 aria-amt transcribe \
-    medium-stacked \
+    medium-double \
     <path-to-checkpoint> \
     -load_path <path-to-audio> \
     -save_dir <path-to-save-dir> \
     -bs 1 \
-    -compile \
-    -q8
+    -compile
 ```

 If you want to do batch transcription, use the `-load_dir` flag and adjust `-bs` accordingly. Compiling may take some time, but provides a significant speedup. Quantizing (`-q8` flag) further speeds up inference when the `-compile` flag is also used.
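For illustration, a batch run along the lines described in that paragraph might look like the following, swapping `-load_path` for `-load_dir` as the text suggests; the audio-directory placeholder and the batch size of 8 are illustrative values, not taken from the repository, while the remaining flags are the ones shown above:

```
aria-amt transcribe \
    medium-double \
    <path-to-checkpoint> \
    -load_dir <path-to-audio-dir> \
    -save_dir <path-to-save-dir> \
    -bs 8 \
    -compile \
    -q8
```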
10 changes: 5 additions & 5 deletions amt/inference/quantize.py
@@ -45,7 +45,7 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):

 def replace_linear_weight_only_int8_per_channel(module):
     for name, child in module.named_children():
-        if isinstance(child, nn.Linear):
+        if isinstance(child, nn.Linear) and name != "output":
             if child.bias is not None:
                 setattr(
                     module,
@@ -71,13 +71,13 @@ def __init__(self, mod: torch.nn.Module):
     @torch.no_grad()
     def create_quantized_state_dict(self):
         cur_state_dict = self.mod.state_dict()
-        for fqn, mod in self.mod.named_modules():
-            if isinstance(mod, torch.nn.Linear):
+        for name, mod in self.mod.named_modules():
+            if isinstance(mod, torch.nn.Linear) and name != "output":
                 int8_weight, scales, _ = dynamically_quantize_per_channel(
                     mod.weight.float(), -128, 127, torch.int8
                 )
-                cur_state_dict[f"{fqn}.weight"] = int8_weight.to("cpu")
-                cur_state_dict[f"{fqn}.scales"] = scales.to(
+                cur_state_dict[f"{name}.weight"] = int8_weight.to("cpu")
+                cur_state_dict[f"{name}.scales"] = scales.to(
                     mod.weight.dtype
                 ).to("cpu")

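The hunks above only show the renamed loop variable and the new `name != "output"` filter. As background, here is a minimal, self-contained sketch of per-channel symmetric int8 weight quantization; the function name, toy shapes, and epsilon clamp are illustrative assumptions, not the actual `dynamically_quantize_per_channel` implementation:

```python
import torch

def quantize_weight_per_channel(w: torch.Tensor, quant_min=-128, quant_max=127):
    # One scale per output channel (row of a Linear weight [out_features, in_features]).
    max_abs = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    scales = max_abs / quant_max
    q = torch.clamp(torch.round(w / scales), quant_min, quant_max).to(torch.int8)
    return q, scales.squeeze(1)

w = torch.randn(16, 32)                # toy weight matrix
q, scales = quantize_weight_per_channel(w)
w_hat = q.float() * scales[:, None]    # dequantize to check the round trip
print((w - w_hat).abs().max())         # error stays within half a quantization step
```

In the commit itself, the added `name != "output"` guard simply leaves the module registered as `output` un-quantized, in both the weight-replacement pass and the quantized state-dict construction.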
92 changes: 63 additions & 29 deletions amt/inference/transcribe.py
@@ -41,6 +41,9 @@
 CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR


+# TODO: Implement continuous batching in a torch.compile friendly way
+
+
 def _setup_logger(name: str | None = None):
     logger_name = f"[{name}] " if name else ""
     logger = logging.getLogger(__name__)
@@ -187,6 +190,26 @@ def prefill(
     return logits, next_tok_ids


+# This is not used anywhere but may come in handy one day
+def calculate_input_pos(prefix_lens: torch.Tensor, prefill: bool):
+    # Given prefix lens e.g. [67, 2, 9], generates the input positions,
+    # left-padded with -1
+    max_pos = torch.max(prefix_lens)
+    pos_idxs = torch.stack(
+        [
+            torch.cat(
+                [
+                    torch.full((max_pos - pos,), -1, dtype=torch.long),
+                    torch.arange(pos),
+                ]
+            )
+            for pos in prefix_lens
+        ]
+    )
+
+    return pos_idxs
+
+
 @optional_bf16_autocast
 @torch.no_grad()
 def process_segments(
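As a quick illustration of the `calculate_input_pos` helper added above, using the prefix lengths from its own comment (the expected output assumes the function behaves exactly as written):

```python
import torch

prefix_lens = torch.tensor([67, 2, 9])
pos_idxs = calculate_input_pos(prefix_lens, prefill=True)

print(pos_idxs.shape)  # torch.Size([3, 67]) - every row padded to the longest prefix
print(pos_idxs[1])     # 65 leading -1s, then positions 0 and 1
```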
@@ -389,8 +412,8 @@ def _find_min_diff_batch(tasks: List, batch_size: int):
 # - For some reason copying gpu_waiting_dict is not working properly and is
 #   leading to race conditions. I've implemented a lock to stop it.
 # - The size of gpu_batch_queue decreases before the code for deleting the
-#   corresponding entry in gpu_waiting_dict get processed. Adding a short sleep
-#   is a workaround
+#   corresponding entry in gpu_waiting_dict gets processed. Adding a short
+#   sleep is a workaround
 def gpu_batch_manager(
     gpu_task_queue: Queue,
     gpu_batch_queue: Queue,
@@ -847,16 +870,21 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
         logger.info(f"{file_queue.qsize()} file(s) remaining in queue")


+def cleanup_processes(child_pids: List):
+    for pid in child_pids:
+        try:
+            os.kill(pid, signal.SIGTERM)
+        except ProcessLookupError as e:
+            pass
+        except Exception as e:
+            print(f"Failed to kill child process: {e}")
+
+
 def watchdog(main_pids: List, child_pids: List):
     while True:
         if not all(os.path.exists(f"/proc/{pid}") for pid in main_pids):
             print("Watchdog cleaning up children...")
-            for pid in child_pids:
-                try:
-                    os.kill(pid, signal.SIGTERM)
-                except ProcessLookupError:
-                    pass
-
+            cleanup_processes(child_pids=child_pids)
             print("Watchdog exit.")
             return

@@ -958,10 +986,10 @@ def batch_transcribe(

     if num_workers is None:
         num_workers = min(
-            min(batch_size * num_gpus, multiprocessing.cpu_count() - num_gpus),
+            min(batch_size, multiprocessing.cpu_count() - num_gpus),
             file_queue.qsize(),
         )
-    num_processes_per_worker = min(5, file_queue.qsize() // num_workers)
+    num_processes_per_worker = min(10, file_queue.qsize() // num_workers)

     mp_manager = Manager()
     gpu_waiting_dict = mp_manager.dict()
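To make the new sizing formula concrete, here is the arithmetic under assumed values (an 8-item batch, 2 GPUs, a 32-core host, and 200 queued files; none of these numbers come from the repository):

```python
batch_size, num_gpus = 8, 2   # assumed CLI settings
cpu_count = 32                # assumed value of multiprocessing.cpu_count()
files_queued = 200            # assumed value of file_queue.qsize()

num_workers = min(min(batch_size, cpu_count - num_gpus), files_queued)
num_processes_per_worker = min(10, files_queued // num_workers)
print(num_workers, num_processes_per_worker)  # 8 10
```

Compared with the old formula, the worker count now scales with the batch size alone rather than `batch_size * num_gpus`, while each worker may use up to 10 helper processes instead of 5.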
@@ -1070,27 +1098,33 @@ def batch_transcribe(
     )
     watchdog_process.start()

-    for p in worker_processes:
-        p.join()
-
-    if gpu_manager_processes is not None:
-        for p in gpu_manager_processes:
-            p.terminate()
-
-    file_queue.close()
-    file_queue.join_thread()
-    gpu_task_queue.close()
-    gpu_task_queue.join_thread()
-    gpu_batch_queue.close()
-    gpu_batch_queue.join_thread()
-    result_queue.close()
-    result_queue.join_thread()
-
-    gpu_batch_manager_process.terminate()
-    gpu_batch_manager_process.join()
-    watchdog_process.terminate()
-    watchdog_process.join()
-
+    try:
+        for p in worker_processes:
+            p.join()
+
+        if gpu_manager_processes is not None:
+            for p in gpu_manager_processes:
+                p.terminate()
+                p.join()
+
+    except KeyboardInterrupt:
+        logger.info("Main process received KeyboardInterrupt")
+        logger.info("Cleaning up child processes")
+        cleanup_processes(child_pids=child_pids)
+        logger.info("Complete")
+    finally:
+        gpu_batch_manager_process.terminate()
+        gpu_batch_manager_process.join()
+        watchdog_process.terminate()
+        watchdog_process.join()
+        file_queue.close()
+        file_queue.join_thread()
+        gpu_task_queue.close()
+        gpu_task_queue.join_thread()
+        gpu_batch_queue.close()
+        gpu_batch_queue.join_thread()
+        result_queue.close()
+        result_queue.join_thread()

     time_taken_s = int(time.time() - start_time)
     logger.info(
3 changes: 0 additions & 3 deletions amt/mir.py
@@ -1,15 +1,12 @@
 import glob
 from tqdm.auto import tqdm
-import pretty_midi
 import numpy as np
 import mir_eval
 import json
 import os

 from aria.data.midi import MidiDict, get_duration_ms

-pretty_midi.pretty_midi.MAX_TICK = 1e10
-

 def midi_to_intervals_and_pitches(midi_file_path):
     mid_dict = MidiDict.from_midi(midi_file_path)
2 changes: 1 addition & 1 deletion amt/train.py
@@ -175,7 +175,7 @@ def get_pretrain_optim(
     num_epochs: int,
     steps_per_epoch: int,
 ):
-    LR = 5e-4
+    LR = 3e-4
     END_RATIO = 0.1
     WARMUP_STEPS = 1000

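The constants above (a peak LR, an `END_RATIO` of 0.1, and 1000 warmup steps) suggest a warmup-then-decay schedule. The sketch below is purely illustrative, linear warmup followed by linear decay to `END_RATIO * LR`, and is not necessarily the scheduler that `get_pretrain_optim` actually constructs:

```python
def example_lr(step: int, total_steps: int, lr=3e-4, end_ratio=0.1, warmup_steps=1000):
    # Illustrative shape only: linear warmup, then linear decay to end_ratio * lr.
    if step < warmup_steps:
        return lr * step / warmup_steps
    frac = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return lr * (1.0 - frac * (1.0 - end_ratio))

print(example_lr(500, 10_000))     # 1.5e-04, halfway through warmup
print(example_lr(10_000, 10_000))  # ~3e-05, i.e. END_RATIO * LR at the end
```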
File renamed without changes.
File renamed without changes.
11 changes: 0 additions & 11 deletions config/models/small-final.json

This file was deleted.

3 changes: 1 addition & 2 deletions requirements-dev.txt
@@ -1,2 +1 @@
-black
-matplotlib
+black
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
58 changes: 0 additions & 58 deletions scripts/eval/adjust.py

This file was deleted.

15 changes: 15 additions & 0 deletions scripts/transcribe-youtube.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+mkdir -p data models
+[ -f data/audio.webm ] && rm data/audio.webm
+[ -f data/audio.mp3 ] && rm data/audio.mp3
+
+yt-dlp -x --audio-format mp3 --extract-audio --audio-quality 0 --no-playlist -o data/audio.mp3 $1
+
+aria-amt transcribe \
+    medium-double \
+    models/piano-medium-double-1.0.safetensors \
+    -load_path data/audio.mp3 \
+    -save_dir data \
+    -compile \
+    -bs 1
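The script takes a single YouTube URL as its only argument (consumed as `$1` by the `yt-dlp` call), so a typical invocation would be something like `bash scripts/transcribe-youtube.sh <youtube-url>`. It downloads the audio to `data/audio.mp3`, writes the transcription into `data/`, and expects the checkpoint at `models/piano-medium-double-1.0.safetensors`; the `models` directory is created by the script, but the weights themselves come from the wget command in the README.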
