Fix prefill CUDA memory leakage (#45)
loubbrad authored Jul 10, 2024
1 parent effaa99 commit 525b53f
Showing 3 changed files with 84 additions and 57 deletions.
103 changes: 60 additions & 43 deletions amt/inference/transcribe.py
@@ -42,7 +42,6 @@ def _setup_logger(name: str | None = None):

logger.propagate = False
logger.setLevel(logging.DEBUG)
# Adjust the formatter to include the name before the PID if provided
formatter = logging.Formatter(
f"[%(asctime)s] {logger_name}%(process)d: [%(levelname)s] %(message)s",
)
@@ -169,7 +168,7 @@ def prefill(
x_input_pos: torch.Tensor,
xa_input_pos: torch.Tensor,
):
# This is the same as decode_token, however we don't compile the prefill
# This is the same as decode_token and is separate for compilation reasons
logits = model.decoder.forward(
x=x,
xa=xa,
@@ -216,18 +215,20 @@ def process_segments(
)
):
# for idx in range(min_prefix_len, MAX_BLOCK_LEN - 1):
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
if idx == min_prefix_len:
logits, next_tok_ids = decode_token(
model,
x=seq[:, :idx],
xa=audio_features,
x_input_pos=torch.arange(0, idx, device=seq.device),
xa_input_pos=torch.arange(
0, audio_features.shape[1], device=seq.device
),
)
else:
if idx == min_prefix_len:
logits, next_tok_ids = prefill(
model,
x=seq[:, :idx],
xa=audio_features,
x_input_pos=torch.arange(0, idx, device=seq.device),
xa_input_pos=torch.arange(
0, audio_features.shape[1], device=seq.device
),
)
else:
with torch.nn.attention.sdpa_kernel(
torch.nn.attention.SDPBackend.MATH
):
logits, next_tok_ids = decode_token(
model,
x=seq[:, idx - 1 : idx],
@@ -241,7 +242,7 @@
)
assert not torch.isnan(logits).any(), "NaN seen in logits"

logits[:, 389] *= 1.05
logits[:, 389] *= 1.05 # Increase pedal-off msg logits
next_tok_ids = torch.argmax(logits, dim=-1)

next_tok_ids = recalculate_tok_ids(
@@ -278,18 +279,18 @@ def gpu_manager(
result_queue: Queue,
model: AmtEncoderDecoder,
batch_size: int,
compile: bool = False,
compile_mode: str | bool = False,
gpu_id: int | None = None,
):
if gpu_id:
if gpu_id is not None:
logger = _setup_logger(name=f"GPU-{gpu_id}")
else:
logger = _setup_logger(name=f"GPU")

logger.info("Started GPU manager")

if gpu_id is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
torch.cuda.set_device(gpu_id)

model.decoder.setup_cache(
batch_size=batch_size,
@@ -298,16 +299,11 @@
)
model.cuda()
model.eval()
if compile is True:
global decode_token, recalculate_tok_ids
if batch_size == 1:
recalculate_tok_ids = torch.compile(
recalculate_tok_ids, mode="max-autotune-no-cudagraphs"
)
if compile_mode is not False:
global decode_token
decode_token = torch.compile(
decode_token,
mode="reduce-overhead",
# mode="max-autotune",
mode=compile_mode,
fullgraph=True,
)
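Taken together with the prefill and process_segments hunks above, the new structure compiles only the fixed-shape decode_token step (using whichever compile_mode the caller selects) and leaves the variable-length prefill eager, which appears to be what the "separate for compilation reasons" comment refers to. A minimal, self-contained sketch of that pattern — the names below are illustrative stand-ins, not the repository's API:

    import torch

    model = torch.nn.Identity()  # stand-in for the real decoder

    def prefill(mdl, x):
        # Variable-length prefix step: deliberately left uncompiled.
        return mdl(x)

    def decode_token(mdl, x):
        # Fixed-shape single-token step: the only function handed to torch.compile.
        return mdl(x)

    def setup(compile_mode):
        # Mirrors gpu_manager: compile_mode is "reduce-overhead", "max-autotune", or False.
        global decode_token
        if compile_mode is not False:
            decode_token = torch.compile(decode_token, mode=compile_mode, fullgraph=True)

    def generate(seq, min_prefix_len, max_len):
        for idx in range(min_prefix_len, max_len):
            if idx == min_prefix_len:
                logits = prefill(model, seq[:, :idx])  # whole prefix, eager
            else:
                with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
                    logits = decode_token(model, seq[:, idx - 1 : idx])  # one token, compiled
            # ...argmax / sampling would go here...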

@@ -695,7 +691,7 @@ def transcribe_file(

def get_save_path(
file_path: str,
input_dir: str,
input_dir: str | None,
save_dir: str,
idx: int | str = "",
):
@@ -848,7 +844,6 @@ def worker(
logger.info(f"File worker terminated")


# Needs to test this for multi-gpu
def batch_transcribe(
file_paths: list,
model: AmtEncoderDecoder,
@@ -857,9 +852,15 @@ def batch_transcribe(
input_dir: str | None = None,
gpu_ids: int | None = None,
quantize: bool = False,
compile: bool = False,
compile_mode: str | bool = False,
):
assert os.name == "posix", "UNIX/LINUX is the only supported OS"
assert compile_mode in {
"reduce-overhead",
"max-autotune",
False,
}, "Invalid value for compile_mode"

torch.multiprocessing.set_start_method("spawn")
num_gpus = len(gpu_ids) if gpu_ids is not None else 1
logger = _setup_logger()
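batch_transcribe now rejects unsupported compile_mode values up front; it also (as before) relies on the spawn start method, since CUDA state generally cannot be reused in forked child processes. A rough sketch of the same guard-then-spawn pattern, with hypothetical worker names rather than the repository's:

    import torch.multiprocessing as mp

    def launch(worker_fn, num_workers, compile_mode):
        # Fail fast on an unsupported mode, mirroring the assert added above.
        assert compile_mode in {"reduce-overhead", "max-autotune", False}, (
            "Invalid value for compile_mode"
        )
        # spawn, not fork: each child re-initializes CUDA itself.
        ctx = mp.get_context("spawn")
        procs = [ctx.Process(target=worker_fn, args=(compile_mode,)) for _ in range(num_workers)]
        for p in procs:
            p.start()
        return procs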
@@ -933,39 +934,55 @@ def batch_transcribe(
result_queue,
model,
batch_size,
compile,
compile_mode,
gpu_id,
),
)
for gpu_id in gpu_ids
for gpu_id in range(len(gpu_ids))
]
for p in gpu_manager_processes:
child_pids.append(gpu_manager_processes.pid)
p.start()
watchdog_process = multiprocessing.Process(
target=watchdog, args=(gpu_batch_manager_process.pid, child_pids)
target=watchdog, args=(os.getpid(), child_pids)
)
watchdog_process.start()
else:
gpu_manager_processes = None
_gpu_manager_process = multiprocessing.Process(
target=gpu_manager,
args=(
gpu_batch_queue,
result_queue,
model,
batch_size,
compile_mode,
),
)
child_pids.append(_gpu_manager_process.pid)
_gpu_manager_process.start()
gpu_manager_processes = [_gpu_manager_process]

watchdog_process = multiprocessing.Process(
target=watchdog, args=(os.getpid(), child_pids)
)
watchdog_process.start()
gpu_manager(
gpu_batch_queue,
result_queue,
model,
batch_size,
compile,
)

for p in worker_processes:
p.join()

if gpu_manager_processes is not None:
for p in gpu_manager_processes:
p.terminate()
p.join()

for p in worker_processes:
p.terminate()
p.join()
file_queue.close()
file_queue.join_thread()
gpu_task_queue.close()
gpu_task_queue.join_thread()
gpu_batch_queue.close()
gpu_batch_queue.join_thread()
result_queue.close()
result_queue.join_thread()

gpu_batch_manager_process.terminate()
gpu_batch_manager_process.join()
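The most consequential change is in the single-GPU path of batch_transcribe: gpu_manager now runs in its own child process (its pid added to child_pids) instead of inside the parent, the watchdog monitors the calling process via os.getpid() rather than the batch-manager process, and once the workers finish, the manager processes are terminated and every queue is explicitly close()d and join_thread()ed. Keeping the CUDA work out of the long-lived parent and tearing the queues down is consistent with the commit's stated goal of fixing CUDA memory leakage. A condensed, runnable sketch of that lifecycle, using placeholder task/manager functions rather than the repository's own:

    import multiprocessing

    def gpu_manager_stub(task_queue, result_queue):
        # Placeholder for the real gpu_manager: echo tasks back as "results".
        while True:
            result_queue.put(task_queue.get())

    def run_single_gpu(tasks):
        task_queue = multiprocessing.Queue()
        result_queue = multiprocessing.Queue()

        # The GPU-heavy manager lives in a child process, never in the parent.
        manager = multiprocessing.Process(
            target=gpu_manager_stub, args=(task_queue, result_queue)
        )
        manager.start()

        for t in tasks:
            task_queue.put(t)
        results = [result_queue.get() for _ in tasks]

        # Mirror the new cleanup: stop the manager, then close/join each queue
        # so its feeder thread and buffered items are released.
        manager.terminate()
        manager.join()
        for q in (task_queue, result_queue):
            q.close()
            q.join_thread()
        return results

    if __name__ == "__main__":
        print(run_single_gpu(list(range(4))))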
36 changes: 23 additions & 13 deletions amt/run.py
@@ -89,6 +89,12 @@ def _add_transcribe_args(subparser):
action="store_true",
default=False,
)
subparser.add_argument(
"-max_autotune",
help="use mode=max_autotune when compiling",
action="store_true",
default=False,
)
subparser.add_argument("-bs", help="batch size", type=int, default=16)


@@ -341,16 +347,16 @@ def build_maestro(


def transcribe(
model_name,
checkpoint_path,
save_dir,
load_path=None,
load_dir=None,
maestro=False,
batch_size=16,
multi_gpu=False,
quantize=False,
compile=False,
model_name: str,
checkpoint_path: str,
save_dir: str,
load_path: str | None = None,
load_dir: str | None = None,
maestro: bool = False,
batch_size: int = 8,
multi_gpu: bool = False,
quantize: bool = False,
compile_mode: str | bool = False,
):
"""
Transcribe audio files to midi using the given model and checkpoint.
@@ -449,7 +455,7 @@ def transcribe(
input_dir=load_dir,
gpu_ids=gpu_ids,
quantize=quantize,
compile=compile,
compile_mode=compile_mode,
)

else:
Expand All @@ -460,7 +466,7 @@ def transcribe(
batch_size=batch_size,
input_dir=load_dir,
quantize=quantize,
compile=compile,
compile_mode=compile_mode,
)


@@ -528,7 +534,11 @@ def main():
batch_size=args.bs,
multi_gpu=args.multi_gpu,
quantize=args.q8,
compile=args.compile,
compile_mode=(
"max-autotune"
if args.compile and args.max_autotune
else "reduce-overhead" if args.compile else False
),
)
else:
print("Unrecognized command")
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
aria @ git+https://github.com/EleutherAI/aria.git
torch >= 2.2
torch >= 2.3
torchaudio
accelerate
psutil
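requirements.txt now asks for torch >= 2.3 (the hunk above shows the bump from 2.2). For an explicit runtime check rather than relying on pip, a tiny hypothetical guard such as the following would mirror that requirement; it assumes a release-style version string:

    import torch

    major, minor = (int(p) for p in torch.__version__.split(".")[:2])
    assert (major, minor) >= (2, 3), f"torch >= 2.3 required, found {torch.__version__}"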
