eliminate parallel training
LWprogramming committed Jul 18, 2023
1 parent 4639dd2 commit cdb842c
Showing 2 changed files with 33 additions and 77 deletions.
101 changes: 25 additions & 76 deletions audiolm_pytorch_demo_laion.py
@@ -41,9 +41,7 @@
# Checkpoint loading. Expect format to be something like /path/to/semantic.transformer.20000.pt
parser = argparse.ArgumentParser()
parser.add_argument('--slurm_job_id', type=int, help='slurm job id, used for creating results folders', required=True)
-# parallel vs non-parallel training
-parser.add_argument('--no_parallel_training', dest='parallel_training', help="disable parallel training, forcing transformers to train in sequence.", action='store_false')
-parser.add_argument('--parallel_training', dest='parallel_training', help="enable parallel training, forcing transformers to train a little bit before handing off to the next transformer. Good for seeing how results progressively get better.", action='store_true')
+parser.add_argument('--train_or_eval', type=str, help="decide which transformer to train (pick from choices)", choices=["train_semantic", "train_coarse", "train_fine", "evaluate"], required=True)
parser.add_argument('--run_mode',
type=str,
help='run configuration (pick from choices). Sets dataset_folder, num_train_steps, save_every, batch_size, and grad_accum_every',
@@ -55,9 +53,10 @@
"test_long_sample"],
default=None,
required=True)
-parser.set_defaults(parallel_training=False)
parser.add_argument("--with_profiling", type=bool, default=False, nargs='?', const=True)
args = parser.parse_args()
+if args.with_profiling:
+    raise NotImplementedError("Profiling is not implemented yet. see train() function below")
results_folder_suffix = str(args.slurm_job_id)
print("parsed args")

@@ -351,77 +350,27 @@ def get_sample(wav2vec, codec, semantic_transformer, coarse_transformer, fine_tr
sample_rate = 24000
torchaudio.save(output_path, generated_wav.cpu(), sample_rate)

-def train_everything(profiler=None):
+def train(profiler=None):
# pass in profiler as an optional argument. If we're not doing profiling, then nothing happens.
-    # If we're doing profiling (implemented only for training in parallel because it'd require more work to do it in series as well since we need to manually run profiler.step() to work with the profiler schedule), this just profiles the *second* save_every steps segment
-    # see: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-profiler-to-analyze-long-running-jobs for more info on profiler.step() and schedules. We use the second save_every steps instead of the first in case there's any weird overhead to setting things up.
-    if args.parallel_training:
-        print("training in parallel")
-        def train_models(steps_to_train):
-            # Due to preemption in Slurm, we might see a checkpoint for one model but not the others. Since our parallel training trains models in a round-robin fashion, we figure out the first trainer for which we *don't* have a checkpoint and train that one.
-            # This is a bit hacky, but it works.
-            print(
-                f"deciding which model to train. coarse_trainer.steps: {coarse_trainer.steps}, semantic_trainer.steps: {semantic_trainer.steps}, fine_trainer.steps: {fine_trainer.steps}. devices are {coarse_trainer.device}, {semantic_trainer.device}, {fine_trainer.device} for coarse, semantic, and fine respectively")
-
-            if coarse_trainer.steps == 0:
-                trainer_index = 0
-            elif semantic_trainer.steps < coarse_trainer.steps:
-                trainer_index = 1
-            elif semantic_trainer.steps == coarse_trainer.steps and fine_trainer.steps < coarse_trainer.steps:
-                trainer_index = 2
-            else:
-                trainer_index = 0 # all the models have an equal train step-- start from the beginning
-
-            trainers = [coarse_trainer, semantic_trainer, fine_trainer]
-
-            print(f"training model {trainer_index} on device {trainers[trainer_index].device}")
-            for i in range(trainer_index, len(trainers)):
-                start_time = datetime.datetime.now()
-                for _ in range(steps_to_train + 1):
-                    # train + 1 to avoid off-by-one errors-- basically we want to train for steps_to_train steps, but we
-                    # also want to train for one more step so that the modular arithmetic triggers and we actually save
-                    # the checkpoint
-                    trainers[i].train_step()
-                end_time = datetime.datetime.now()
-                elapsed_time = end_time - start_time
-                print(f"Device {trainers[i].device}: Time taken for {steps_to_train} steps of {trainers[i].__class__.__name__}: {elapsed_time}")
-
-        for step in range(0, num_train_steps, save_every):
-            train_models(save_every)
-            if profiler is not None:
-                profiler.step()
-            # TODO: do we need to wait for everyone here or can we just get away with waiting on main only?
-            # I don't think so because we just saved checkpoint right? going to try without
-            # semantic_trainer.accelerator.wait_for_everyone()
-            # coarse_trainer.accelerator.wait_for_everyone()
-            # fine_trainer.accelerator.wait_for_everyone()
-            # print(f"all clear from semantic trainer on device {semantic_trainer.device}")
-            if semantic_trainer.accelerator.is_main_process:
-                # they should all be main process if this is called right?
-                assert coarse_trainer.accelerator.is_main_process and fine_trainer.accelerator.is_main_process
-                print("generating now...")
-                # print(f"semantic_trainer on device {semantic_trainer.device}") # device 0
-                get_sample(wav2vec, codec, semantic_transformer, coarse_transformer, fine_transformer, step)
-            else:
-                assert not(coarse_trainer.accelerator.is_main_process or fine_trainer.accelerator.is_main_process)
-                print(f"not generating on device {semantic_trainer.device}")
-            print("finished generation")
-            print(f"semantic_trainer on device {semantic_trainer.device} arrived here")
-            semantic_trainer.accelerator.wait_for_everyone()
-            print(f"coarse_trainer on device {coarse_trainer.device} arrived here")
-            coarse_trainer.accelerator.wait_for_everyone()
-            print(f"fine_trainer on device {fine_trainer.device} arrived here")
-            fine_trainer.accelerator.wait_for_everyone()
-            print(f"all waits complete.")
+    # TODO: this code needs to happen SOMEWHERE if we're doing profiling.
+    # TODO: see: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-profiler-to-analyze-long-running-jobs for more info on profiler.step() and schedules. We use the second save_every steps instead of the first in case there's any weird overhead to setting things up.
+    # if profiler is not None:
+    #     profiler.step()
+    if args.train_or_eval == "evaluate":
+        step = semantic_trainer.steps.item()
+        assert semantic_trainer.steps.item() == coarse_trainer.steps.item() and coarse_trainer.steps.item() == fine_trainer.steps.item(), "all three trainers should have the same number of steps when fully trained"
+        get_sample(wav2vec, codec, semantic_transformer, coarse_transformer, fine_transformer, step)
+        return
+    elif args.train_or_eval == "train_semantic":
+        trainer = semantic_trainer
+    elif args.train_or_eval == "train_coarse":
+        trainer = coarse_trainer
+    elif args.train_or_eval == "train_fine":
+        trainer = fine_trainer
else:
-        # non parallel training
-        assert profiler is None, "profiling not implemented for training NOT in parallel"
-        print("not training in parallel")
-        semantic_trainer.train()
-        coarse_trainer.train()
-        fine_trainer.train()
-        raise NotImplementedError("haven't implemented checking is main process for non-parallel training yet")
-        get_sample(wav2vec, codec, semantic_transformer, coarse_transformer, fine_transformer, num_train_steps)
+        raise AssertionError(f"train_or_eval argument {args.train_or_eval} not recognized, should be unreachable")
+    print(f"training using {args.train_or_eval} trainer")
+    trainer.train()

def trace_handler(prof):
profile_log = f"{prefix}/profiler_{args.slurm_job_id}.txt"
@@ -473,11 +422,11 @@ def trace_handler(prof):
active=1,
repeat=1),
on_trace_ready=trace_handler) as profiler:
with record_function("train_everything"):
train_everything(profiler=profiler)
with record_function("train"):
train(profiler=profiler)
else:
print("training without profiling")
-    train_everything()
+    train()

# with record_function("model_inference"):
# generated_wav = audiolm(batch_size = 1)
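For anyone picking up the profiling TODOs left in train() above: torch.profiler only advances through its schedule when profiler.step() is called once per iteration, which is why the old parallel loop called it after every save_every block and why the new train() still needs a call site for it. Below is a minimal, self-contained sketch of that pattern, not this repo's code: train_step, print_trace, and the save_every value are placeholder stand-ins (the script's own trace_handler writes the table to a log file instead).

```python
# Minimal sketch of the torch.profiler schedule/step pattern referenced above.
import torch
from torch.profiler import ProfilerActivity, profile, record_function, schedule


def print_trace(prof):
    # Stand-in for the script's trace_handler.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))


def train_step():
    # Placeholder for one optimizer step of whichever transformer is being trained.
    torch.randn(512, 512) @ torch.randn(512, 512)


save_every = 5  # placeholder value

with profile(
    activities=[ProfilerActivity.CPU],
    # Skip the first save_every steps (setup overhead), warm up for one step,
    # then record the next save_every steps exactly once.
    schedule=schedule(wait=save_every, warmup=1, active=save_every, repeat=1),
    on_trace_ready=print_trace,
) as prof:
    for step in range(3 * save_every):
        with record_function("train"):
            train_step()
        prof.step()  # without this call the schedule never advances and nothing is recorded
```

The schedule here skips the first save_every steps and records the following window once, mirroring the "profile the second save_every segment" idea described in the deleted comments.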
9 changes: 8 additions & 1 deletion sbatch_job.sh
@@ -57,4 +57,11 @@ source venv/bin/activate # in case this hasn't already been done
echo "run mode: " $RUN_MODE
echo "with profiling: " $WITH_PROFILING
echo "slurm job id to actually use: " $OVERRIDABLE_SLURM_JOB_ID
-accelerate launch audiolm_pytorch_demo_laion_$OVERRIDABLE_SLURM_JOB_ID.py --run_mode $RUN_MODE $WITH_PROFILING --slurm_job_id $OVERRIDABLE_SLURM_JOB_ID --parallel_training
+
+# Transformers need to be trained separately, see: https://github.com/lucidrains/audiolm-pytorch/issues/209#issuecomment-1640777646
+accelerate launch audiolm_pytorch_demo_laion_$OVERRIDABLE_SLURM_JOB_ID.py --run_mode $RUN_MODE $WITH_PROFILING --slurm_job_id $OVERRIDABLE_SLURM_JOB_ID --train_or_eval train_semantic
+accelerate launch audiolm_pytorch_demo_laion_$OVERRIDABLE_SLURM_JOB_ID.py --run_mode $RUN_MODE $WITH_PROFILING --slurm_job_id $OVERRIDABLE_SLURM_JOB_ID --train_or_eval train_coarse
+accelerate launch audiolm_pytorch_demo_laion_$OVERRIDABLE_SLURM_JOB_ID.py --run_mode $RUN_MODE $WITH_PROFILING --slurm_job_id $OVERRIDABLE_SLURM_JOB_ID --train_or_eval train_fine
+
+# evaluate separately
+accelerate launch audiolm_pytorch_demo_laion_$OVERRIDABLE_SLURM_JOB_ID.py --run_mode $RUN_MODE $WITH_PROFILING --slurm_job_id $OVERRIDABLE_SLURM_JOB_ID --train_or_eval evaluate
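Because these launches run strictly in sequence, a preempted Slurm job can leave one stage finished and the next untouched; the checkpoint naming convention noted at the top of the Python file (e.g. semantic.transformer.20000.pt) is enough to check how far a stage got before relaunching. A hypothetical helper along those lines, where the results-folder name, the 20000-step target, and the function itself are illustrative assumptions rather than part of this repo:

```python
# Hypothetical helper for checking how far a training stage got before relaunching it.
import re
from pathlib import Path


def latest_checkpoint_steps(results_folder: str, stage: str) -> int:
    """Highest step count among checkpoints named '<stage>.transformer.<steps>.pt', or 0."""
    pattern = re.compile(rf"{re.escape(stage)}\.transformer\.(\d+)\.pt$")
    steps = [
        int(m.group(1))
        for p in Path(results_folder).glob("*.pt")
        if (m := pattern.search(p.name))
    ]
    return max(steps, default=0)


# Example: only move on to coarse training once semantic training reached its target.
if latest_checkpoint_steps("results_semantic_12345", "semantic") < 20000:
    raise SystemExit("semantic transformer has not finished training yet")
```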
