From 130eb633dcceaeb3d5554601928b06f421198a63 Mon Sep 17 00:00:00 2001
From: Alejandro Gaston Alvarez
Date: Fri, 16 Jun 2023 14:26:14 +0200
Subject: [PATCH] Don't use hparams internal value

---
 vocos/experiment.py | 56 ++++++++++++++++++++++++++++---------------------
 1 file changed, 33 insertions(+), 23 deletions(-)

diff --git a/vocos/experiment.py b/vocos/experiment.py
index 22857d9..b51aa6a 100644
--- a/vocos/experiment.py
+++ b/vocos/experiment.py
@@ -67,6 +67,16 @@ def __init__(
         self.train_discriminator = False
         self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
 
+        self.sample_rate = sample_rate
+        self.initial_learning_rate = initial_learning_rate
+        self.num_warmup_steps = num_warmup_steps
+        self.mrd_loss_coeff = mrd_loss_coeff
+        self.pretrain_mel_steps = pretrain_mel_steps
+        self.decay_mel_coeff = decay_mel_coeff
+        self.evaluate_utmos = evaluate_utmos
+        self.evaluate_pesq = evaluate_pesq
+        self.evaluate_periodicty = evaluate_periodicty
+
     def configure_optimizers(self):
         disc_params = [
             {"params": self.multiperioddisc.parameters()},
@@ -78,15 +88,15 @@ def configure_optimizers(self):
             {"params": self.head.parameters()},
         ]
 
-        opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate)
-        opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate)
+        opt_disc = torch.optim.AdamW(disc_params, lr=self.initial_learning_rate)
+        opt_gen = torch.optim.AdamW(gen_params, lr=self.initial_learning_rate)
 
         max_steps = self.trainer.max_steps // 2  # Max steps per optimizer
         scheduler_disc = transformers.get_cosine_schedule_with_warmup(
-            opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
+            opt_disc, num_warmup_steps=self.num_warmup_steps, num_training_steps=max_steps
         )
         scheduler_gen = transformers.get_cosine_schedule_with_warmup(
-            opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
+            opt_gen, num_warmup_steps=self.num_warmup_steps, num_training_steps=max_steps
         )
 
         return (
@@ -118,7 +128,7 @@ def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
             )
             loss_mp /= len(loss_mp_real)
             loss_mrd /= len(loss_mrd_real)
-            loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd
+            loss = loss_mp + self.mrd_loss_coeff * loss_mrd
 
             self.log("discriminator/total", loss, prog_bar=True)
             self.log("discriminator/multi_period_loss", loss_mp)
@@ -152,9 +162,9 @@ def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
             mel_loss = self.melspec_loss(audio_hat, audio_input)
             loss = (
                 loss_gen_mp
-                + self.hparams.mrd_loss_coeff * loss_gen_mrd
+                + self.mrd_loss_coeff * loss_gen_mrd
                 + loss_fm_mp
-                + self.hparams.mrd_loss_coeff * loss_fm_mrd
+                + self.mrd_loss_coeff * loss_fm_mrd
                 + self.mel_loss_coeff * mel_loss
             )
 
@@ -164,10 +174,10 @@ def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
 
             if self.global_step % 1000 == 0 and self.global_rank == 0:
                 self.logger.experiment.add_audio(
-                    "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate
+                    "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.sample_rate
                 )
                 self.logger.experiment.add_audio(
-                    "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate
+                    "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.sample_rate
                 )
                 with torch.no_grad():
                     mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
@@ -188,7 +198,7 @@ def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
         return loss
 
     def on_validation_epoch_start(self):
-        if self.hparams.evaluate_utmos:
+        if self.evaluate_utmos:
             from metrics.UTMOS import UTMOSScore
 
             if not hasattr(self, "utmos_model"):
@@ -198,22 +208,22 @@ def validation_step(self, batch, batch_idx, **kwargs):
         audio_input = batch
         audio_hat = self(audio_input, **kwargs)
 
-        audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000)
-        audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000)
+        audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.sample_rate, new_freq=16000)
+        audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.sample_rate, new_freq=16000)
 
-        if self.hparams.evaluate_periodicty:
+        if self.evaluate_periodicty:
             from metrics.periodicity import calculate_periodicity_metrics
 
             periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
         else:
             periodicity_loss = pitch_loss = f1_score = 0
 
-        if self.hparams.evaluate_utmos:
+        if self.evaluate_utmos:
             utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
         else:
             utmos_score = torch.zeros(1, device=self.device)
 
-        if self.hparams.evaluate_pesq:
+        if self.evaluate_pesq:
             from pesq import pesq
 
             pesq_score = 0
@@ -243,10 +253,10 @@ def validation_epoch_end(self, outputs):
         if self.global_rank == 0:
             *_, audio_in, audio_pred = outputs[0].values()
             self.logger.experiment.add_audio(
-                "val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
+                "val_in", audio_in.data.cpu().numpy(), self.global_step, self.sample_rate
             )
             self.logger.experiment.add_audio(
-                "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
+                "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.sample_rate
             )
             mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
             mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
@@ -286,7 +296,7 @@ def global_step(self):
         return self.trainer.fit_loop.epoch_loop.total_batch_idx
 
     def on_train_batch_start(self, *args):
-        if self.global_step >= self.hparams.pretrain_mel_steps:
+        if self.global_step >= self.pretrain_mel_steps:
             self.train_discriminator = True
         else:
             self.train_discriminator = False
@@ -294,14 +304,14 @@ def on_train_batch_end(self, *args):
         def mel_loss_coeff_decay(current_step, num_cycles=0.5):
             max_steps = self.trainer.max_steps // 2
-            if current_step < self.hparams.num_warmup_steps:
+            if current_step < self.num_warmup_steps:
                 return 1.0
-            progress = float(current_step - self.hparams.num_warmup_steps) / float(
-                max(1, max_steps - self.hparams.num_warmup_steps)
+            progress = float(current_step - self.num_warmup_steps) / float(
+                max(1, max_steps - self.num_warmup_steps)
             )
             return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
 
-        if self.hparams.decay_mel_coeff:
+        if self.decay_mel_coeff:
             self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)
 
 
@@ -365,7 +375,7 @@ def validation_epoch_end(self, outputs):
             self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
             encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
             self.logger.experiment.add_audio(
-                "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate,
+                "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.sample_rate,
             )
         super().validation_epoch_end(outputs)
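
As a usage illustration, the sketch below shows the pattern this patch adopts: constructor
arguments are mirrored onto plain attributes in __init__ and read from those attributes at
runtime, while save_hyperparameters() still records them for checkpointing. SketchExp is a
hypothetical toy module used only for illustration, not the real VocosExp.

import pytorch_lightning as pl
import torch


class SketchExp(pl.LightningModule):
    """Toy module illustrating the attribute-over-hparams pattern from the patch."""

    def __init__(self, sample_rate: int, initial_learning_rate: float, num_warmup_steps: int = 0):
        super().__init__()
        # Keep recording hyperparameters for checkpointing/logging ...
        self.save_hyperparameters()
        # ... but store the values used at runtime as explicit attributes.
        self.sample_rate = sample_rate
        self.initial_learning_rate = initial_learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.proj = torch.nn.Linear(1, 1)  # placeholder layer so the optimizer has parameters

    def configure_optimizers(self):
        # Reads self.initial_learning_rate rather than self.hparams.initial_learning_rate.
        return torch.optim.AdamW(self.parameters(), lr=self.initial_learning_rate)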