Skip to content

Commit

Permalink
Change overwrite logic (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad authored Sep 29, 2024
1 parent 4120f57 commit 5a904e6
Showing 1 changed file with 40 additions and 25 deletions.
65 changes: 40 additions & 25 deletions amt/inference/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,17 +783,17 @@ def get_save_path(
file_path: str,
input_dir: str | None,
save_dir: str,
idx: int | str = "",
idx_str: int | str = "",
):
if input_dir is None:
save_path = os.path.join(
save_dir,
os.path.splitext(os.path.basename(file_path))[0] + f"{idx}.mid",
os.path.splitext(os.path.basename(file_path))[0] + f"{idx_str}.mid",
)
else:
input_rel_path = os.path.relpath(file_path, input_dir)
save_path = os.path.join(
save_dir, os.path.splitext(input_rel_path)[0] + f"{idx}.mid"
save_dir, os.path.splitext(input_rel_path)[0] + f"{idx_str}.mid"
)
if not os.path.isdir(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path), exist_ok=True)
Expand All @@ -810,7 +810,7 @@ def process_file(
save_dir: str,
input_dir: str,
logger: logging.Logger,
segments: List[Tuple[int, int]] | None = None,
segments: List[Tuple[int, Tuple[int, int]]] | None = None,
):
def _save_seq(_seq: List, _save_path: str):
if os.path.exists(_save_path):
Expand Down Expand Up @@ -852,12 +852,17 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):

pid = threading.get_ident()
if segments is None:
segments = [None]
# process_file and get_wav_segments will interpret segment=None as
# processing the entire file
segments = [(None, None)]

if len(segments) == 0:
logger.info(f"No segments to transcribe, skipping file: {file_path}")

for idx, segment in enumerate(segments):
for idx, segment in segments:
idx_str = f"_{idx}" if idx is not None else ""
save_path = get_save_path(file_path, input_dir, save_dir, idx_str)

try:
seq = transcribe_file(
file_path,
Expand All @@ -876,15 +881,17 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
logger.info(f"Removed {res_rmv_cnt} from result queue")
continue

logger.info(f"Finished file: {file_path} (segment: {idx})")
logger.info(
f"Finished file: {file_path} (segment: {idx if idx is not None else 'full'})"
)
if len(seq) < 500:
logger.info(f"Skipping seq - too short (segment {idx})")
logger.info(
f"Skipping seq - too short (segment {idx if idx is not None else 'full'})"
)
else:
logger.debug(
f"Saving seq of length {len(seq)} from file: {file_path} (segment: {idx})"
f"Saving seq of length {len(seq)} from file: {file_path} (segment: {idx if idx is not None else 'full'})"
)
idx = f"_{idx}" if segment is not None else ""
save_path = get_save_path(file_path, input_dir, save_dir, idx)
_save_seq(seq, save_path)

logger.info(f"{file_queue.qsize()} file(s) remaining in queue")
Expand Down Expand Up @@ -997,20 +1004,28 @@ def batch_transcribe(
files_to_process, key=lambda x: os.path.getsize(x["path"]), reverse=True
)
for file_to_process in files_to_process:
# Only add to file_queue if transcription MIDI file doesn't exist
if (
os.path.isfile(
if "segments" in file_to_process:
# Process files with segments
unsaved_segments = []
for idx, segment in enumerate(file_to_process["segments"]):
segment_save_path = get_save_path(
file_to_process["path"],
input_dir,
save_dir,
idx_str=f"_{idx}",
)
if not os.path.isfile(segment_save_path):
unsaved_segments.append((idx, segment))

if unsaved_segments:
file_to_process["segments"] = unsaved_segments
file_queue.put(file_to_process)
else:
# Process files without segments (whole file)
if not os.path.isfile(
get_save_path(file_to_process["path"], input_dir, save_dir)
)
is False
) and os.path.isfile(
get_save_path(
file_to_process["path"], input_dir, save_dir, idx="_0"
)
) is False:
file_queue.put(file_to_process)
elif len(files_to_process) == 1:
file_queue.put(file_to_process)
):
file_queue.put(file_to_process)

logger.info(
f"Files to process: {file_queue.qsize()}/{len(files_to_process)}"
Expand All @@ -1026,7 +1041,7 @@ def batch_transcribe(
file_queue.qsize(),
)
num_processes_per_worker = min(
3 * (batch_size // num_workers), file_queue.qsize() // num_workers
5 * (batch_size // num_workers), file_queue.qsize() // num_workers
)

mp_manager = Manager()
Expand Down

0 comments on commit 5a904e6

Please sign in to comment.