Skip to content

Commit

Permalink
only divide by MAX_WAV_VALUE if int16, pad mel/audio if smaller than segment_size, switch tqdm, crop mel/wav in train if necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
lars76 committed Jul 29, 2024
1 parent 455f3f3 commit bff7dd0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
17 changes: 9 additions & 8 deletions meldataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def __getitem__(self, index):
filename = self.audio_files[index]
if self._cache_ref_count == 0:
audio, sampling_rate = load_wav(filename, self.sampling_rate)
audio = audio / MAX_WAV_VALUE
if np.abs(audio).max() > 1:
audio = audio / MAX_WAV_VALUE
if not self.fine_tuning:
audio = normalize(audio) * 0.95
self.cached_wav = audio
Expand Down Expand Up @@ -328,13 +329,13 @@ def __getitem__(self, index):
* self.hop_size : (mel_start + frames_per_seg)
* self.hop_size,
]
else:
mel = torch.nn.functional.pad(
mel, (0, frames_per_seg - mel.size(2)), "constant"
)
audio = torch.nn.functional.pad(
audio, (0, self.segment_size - audio.size(1)), "constant"
)

mel = torch.nn.functional.pad(
mel, (0, frames_per_seg - mel.size(2)), "constant"
)
audio = torch.nn.functional.pad(
audio, (0, self.segment_size - audio.size(1)), "constant"
)

mel_loss = mel_spectrogram(
audio,
Expand Down
8 changes: 5 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def validate(rank, a, h, loader, mode="seen"):
print(f"step {steps} {mode} speaker validation...")

# Loop over validation set and compute metrics
for j, batch in tqdm(enumerate(loader)):
for j, batch in enumerate(tqdm(loader)):
x, y, _, y_mel = batch
y = y.to(device)
if hasattr(generator, "module"):
Expand All @@ -326,7 +326,8 @@ def validate(rank, a, h, loader, mode="seen"):
h.fmin,
h.fmax_for_loss,
)
val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item()

# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
if (
Expand All @@ -343,7 +344,8 @@ def validate(rank, a, h, loader, mode="seen"):
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")

# MRSTFT calculation
val_mrstft_tot += loss_mrstft(y_g_hat, y).item()
min_t = min(y.size(-1), y_g_hat.size(-1))
val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item()

# Log audio and figures to Tensorboard
if j % a.eval_subsample == 0: # Subsample every nth from validation set
Expand Down

0 comments on commit bff7dd0

Please sign in to comment.