diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 52b5bd9..f18705f 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -11,8 +11,10 @@ import torch._dynamo.config import torch._inductor.config import numpy as np +import concurrent from torch.multiprocessing import Queue +from concurrent.futures import ThreadPoolExecutor from tqdm import tqdm from functools import wraps from torch.cuda import is_bf16_supported @@ -313,7 +315,7 @@ def gpu_manager( try: while True: try: - batch = gpu_batch_queue.get(timeout=30) + batch = gpu_batch_queue.get(timeout=60) except Exception as e: logger.info(f"GPU timed out waiting for batch") break @@ -380,12 +382,13 @@ def gpu_batch_manager( tasks = [] while True: try: - task, pid = gpu_task_queue.get(timeout=0.05) + task, pid = gpu_task_queue.get(timeout=0.1) except Exception as e: pass else: tasks.append((task, pid)) - continue + if gpu_batch_queue.empty() is False: + continue # No tasks in queue -> check gpu batch queue if gpu_batch_queue.empty() is False: @@ -784,9 +787,9 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int): logger.info(f"{file_queue.qsize()} file(s) remaining in queue") -def watchdog(main_gpu_pid: int, child_pids: list): +def watchdog(main_pids: list, child_pids: list): while True: - if not os.path.exists(f"/proc/{main_gpu_pid}"): + if not all(os.path.exists(f"/proc/{pid}") for pid in main_pids): print("Cleaning up children...") for pid in child_pids: try: @@ -805,43 +808,44 @@ def worker( result_queue: Queue, save_dir: str, input_dir: str | None = None, - tasks_per_worker: int = 1, + tasks_per_worker: int = 5, ): logger = _setup_logger(name="F") tokenizer = AmtTokenizer() - threads = [] - try: - while not file_queue.empty() or any(t.is_alive() for t in threads): - while len(threads) < tasks_per_worker and not file_queue.empty(): - logging.info("Starting worker") - file_path = file_queue.get() - t = threading.Thread( - target=process_file, - args=( - file_path, - file_queue, - gpu_task_queue, - result_queue, - tokenizer, - save_dir, - input_dir, - logger, - ), - ) - t.start() - threads.append(t) - - threads = [t for t in threads if t.is_alive()] - time.sleep(0.1) + def process_file_wrapper(): + while True: + try: + file_path = file_queue.get(timeout=5) + except Exception as e: + if file_queue.empty(): + logger.info("File queue empty") + break + else: + continue - for t in threads: - t.join() + process_file( + file_path, + file_queue, + gpu_task_queue, + result_queue, + tokenizer, + save_dir, + input_dir, + logger, + ) + try: + with ThreadPoolExecutor(max_workers=tasks_per_worker) as executor: + futures = [ + executor.submit(process_file_wrapper) + for _ in range(tasks_per_worker) + ] + concurrent.futures.wait(futures) except Exception as e: logger.error(f"File worker failed with exception: {e}") finally: - logger.info(f"File worker terminated") + logger.info("File worker terminated") def batch_transcribe( @@ -941,7 +945,7 @@ def batch_transcribe( for gpu_id in range(len(gpu_ids)) ] for p in gpu_manager_processes: - child_pids.append(gpu_manager_processes.pid) + child_pids.append(p.pid) p.start() watchdog_process = multiprocessing.Process( target=watchdog, args=(os.getpid(), child_pids) @@ -963,7 +967,15 @@ def batch_transcribe( gpu_manager_processes = [_gpu_manager_process] watchdog_process = multiprocessing.Process( - target=watchdog, args=(os.getpid(), child_pids) + target=watchdog, + args=( + [ + os.getpid(), + gpu_batch_manager_process.pid, + _gpu_manager_process.pid, + ], + child_pids, + ), ) watchdog_process.start()