diff --git a/amt/data.py b/amt/data.py index d904c8e..98f9244 100644 --- a/amt/data.py +++ b/amt/data.py @@ -111,12 +111,9 @@ def get_wav_segments( yield buffer if pad_last == True: - yield torch.cat( - ( - buffer[stride_samples:], - torch.zeros(stride_samples, dtype=torch.float32), - ), - dim=0, + yield torch.nn.functional.pad( + buffer[stride_samples:], + (0.0, chunk_samples - len(buffer[stride_samples:])), )