Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS] Fix audio codec type checks #7373

Merged
merged 2 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions nemo/collections/tts/losses/audio_codec_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def __init__(self, loss_fn, loss_scale: float = 1.0):
@property
def input_types(self):
return {
- "target": NeuralType(('B', 'D', 'T'), RegressionValuesType()),
"predicted": NeuralType(('B', 'D', 'T'), PredictionsType()),
+ "target": NeuralType(('B', 'D', 'T'), RegressionValuesType()),
"target_len": NeuralType(tuple('B'), LengthsType()),
}

Expand Down Expand Up @@ -97,7 +97,7 @@ def input_types(self):
@property
def output_types(self):
return {
- "loss": [NeuralType(elements_type=LossType())],
+ "loss": NeuralType(elements_type=LossType()),
}

@typecheck()
Expand Down Expand Up @@ -146,7 +146,7 @@ def input_types(self):
@property
def output_types(self):
return {
- "loss": [NeuralType(elements_type=LossType())],
+ "loss": NeuralType(elements_type=LossType()),
}

@typecheck()
Expand Down
6 changes: 4 additions & 2 deletions nemo/collections/tts/models/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,11 @@ def configure_optimizers(self):
sched_config = optim_config.pop("sched", None)
OmegaConf.set_struct(optim_config, True)

- gen_params = itertools.chain(self.audio_encoder.parameters(), self.audio_decoder.parameters())
- disc_params = self.discriminator.parameters()
+ vq_params = self.vector_quantizer.parameters() if self.vector_quantizer else []
+ gen_params = itertools.chain(self.audio_encoder.parameters(), self.audio_decoder.parameters(), vq_params)
optim_g = instantiate(optim_config, params=gen_params)
+
+ disc_params = self.discriminator.parameters()
optim_d = instantiate(optim_config, params=disc_params)

if sched_config is None:
Expand Down
28 changes: 15 additions & 13 deletions nemo/collections/tts/modules/audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- from typing import Iterable, Optional, Tuple
+ from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange

from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
+ from nemo.core.classes.common import typecheck
from nemo.core.classes.module import NeuralModule
- from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, LengthsType, VoidType
+ from nemo.core.neural_types.elements import LengthsType, VoidType
from nemo.core.neural_types.neural_type import NeuralType


Expand Down Expand Up @@ -64,21 +63,22 @@ def __init__(
def input_types(self):
return {
"inputs": NeuralType(('B', 'C', 'T'), VoidType()),
- "lengths": NeuralType(tuple('B'), LengthsType()),
+ "input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
- "out": [NeuralType(('B', 'C', 'T'), VoidType())],
+ "out": NeuralType(('B', 'C', 'T'), VoidType()),
}

def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv)

- def forward(self, inputs, lengths):
+ @typecheck()
+ def forward(self, inputs, input_len):
out = self.conv(inputs)
- out = mask_sequence_tensor(out, lengths)
+ out = mask_sequence_tensor(out, input_len)
return out


Expand All @@ -101,21 +101,22 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride
def input_types(self):
return {
"inputs": NeuralType(('B', 'C', 'T'), VoidType()),
- "lengths": NeuralType(tuple('B'), LengthsType()),
+ "input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
- "out": [NeuralType(('B', 'C', 'T'), VoidType())],
+ "out": NeuralType(('B', 'C', 'T'), VoidType()),
}

def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv)

- def forward(self, inputs, lengths):
+ @typecheck()
+ def forward(self, inputs, input_len):
out = self.conv(inputs)
- out = mask_sequence_tensor(out, lengths)
+ out = mask_sequence_tensor(out, input_len)
return out


Expand Down Expand Up @@ -151,11 +152,12 @@ def input_types(self):
@property
def output_types(self):
return {
- "out": [NeuralType(('B', 'C', 'H', 'T'), VoidType())],
+ "out": NeuralType(('B', 'C', 'H', 'T'), VoidType()),
}

def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv)

+ @typecheck()
def forward(self, inputs):
return self.conv(inputs)
64 changes: 36 additions & 28 deletions nemo/collections/tts/modules/encodec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,28 +72,29 @@ def __init__(self, channels: int):
def input_types(self):
return {
"inputs": NeuralType(('B', 'C', 'T_input'), VoidType()),
- "lengths": NeuralType(tuple('B'), LengthsType()),
+ "input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
- "out": [NeuralType(('B', 'C', 'T_out'), VoidType())],
+ "out": NeuralType(('B', 'C', 'T_out'), VoidType()),
}

def remove_weight_norm(self):
self.pre_conv.remove_weight_norm()
self.res_conv1.remove_weight_norm()
self.res_conv2.remove_weight_norm()

- def forward(self, inputs, lengths):
+ @typecheck()
+ def forward(self, inputs, input_len):
res = self.activation(inputs)
- res = self.res_conv1(res, lengths)
+ res = self.res_conv1(inputs=res, input_len=input_len)
res = self.activation(res)
- res = self.res_conv2(res, lengths)
+ res = self.res_conv2(inputs=res, input_len=input_len)

- out = self.pre_conv(inputs, lengths) + res
- out = mask_sequence_tensor(out, lengths)
+ out = self.pre_conv(inputs=inputs, input_len=input_len) + res
+ out = mask_sequence_tensor(out, input_len)
return out


Expand All @@ -112,20 +113,21 @@ def __init__(self, dim: int, num_layers: int, rnn_type: str = "lstm", use_skip:
def input_types(self):
return {
"inputs": NeuralType(('B', 'C', 'T'), VoidType()),
- "lengths": NeuralType(tuple('B'), LengthsType()),
+ "input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
- "out": [NeuralType(('B', 'C', 'T'), VoidType())],
+ "out": NeuralType(('B', 'C', 'T'), VoidType()),
}

- def forward(self, inputs, lengths):
+ @typecheck()
+ def forward(self, inputs, input_len):
inputs = rearrange(inputs, "B C T -> B T C")

packed_inputs = nn.utils.rnn.pack_padded_sequence(
- inputs, lengths=lengths.cpu(), batch_first=True, enforce_sorted=False
+ inputs, lengths=input_len.cpu(), batch_first=True, enforce_sorted=False
)
packed_out, _ = self.rnn(packed_inputs)
out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
Expand Down Expand Up @@ -183,15 +185,15 @@ def __init__(
@property
def input_types(self):
return {
- "audio": NeuralType(('B', 'C', 'T_audio'), AudioSignal()),
+ "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
- "encoded": [NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation())],
- "encoded_len": [NeuralType(tuple('B'), LengthsType())],
+ "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
+ "encoded_len": NeuralType(tuple('B'), LengthsType()),
}

def remove_weight_norm(self):
Expand All @@ -201,26 +203,27 @@ def remove_weight_norm(self):
for down_sample_conv in self.down_sample_conv_layers:
down_sample_conv.remove_weight_norm()

+ @typecheck()
def forward(self, audio, audio_len):
encoded_len = audio_len
audio = rearrange(audio, "B T -> B 1 T")
# [B, C, T_audio]
- out = self.pre_conv(audio, encoded_len)
+ out = self.pre_conv(inputs=audio, input_len=encoded_len)
for res_block, down_sample_conv, down_sample_rate in zip(
self.res_blocks, self.down_sample_conv_layers, self.down_sample_rates
):
# [B, C, T]
- out = res_block(out, encoded_len)
+ out = res_block(inputs=out, input_len=encoded_len)
out = self.activation(out)

encoded_len = encoded_len // down_sample_rate
# [B, 2 * C, T / down_sample_rate]
- out = down_sample_conv(out, encoded_len)
+ out = down_sample_conv(inputs=out, input_len=encoded_len)

- out = self.rnn(out, encoded_len)
+ out = self.rnn(inputs=out, input_len=encoded_len)
out = self.activation(out)
# [B, encoded_dim, T_encoded]
- encoded = self.post_conv(out, encoded_len)
+ encoded = self.post_conv(inputs=out, input_len=encoded_len)
return encoded, encoded_len


Expand Down Expand Up @@ -274,7 +277,7 @@ def input_types(self):
@property
def output_types(self):
return {
- "audio": NeuralType(('B', 'C', 'T_audio'), AudioSignal()),
+ "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
}

Expand All @@ -285,23 +288,24 @@ def remove_weight_norm(self):
for res_block in self.res_blocks:
res_block.remove_weight_norm()

+ @typecheck()
def forward(self, inputs, input_len):
audio_len = input_len
# [B, C, T_encoded]
- out = self.pre_conv(inputs, audio_len)
- out = self.rnn(out, audio_len)
+ out = self.pre_conv(inputs=inputs, input_len=audio_len)
+ out = self.rnn(inputs=out, input_len=audio_len)
for res_block, up_sample_conv, up_sample_rate in zip(
self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates
):
audio_len = audio_len * up_sample_rate
out = self.activation(out)
# [B, C / 2, T * up_sample_rate]
- out = up_sample_conv(out, audio_len)
- out = res_block(out, audio_len)
+ out = up_sample_conv(inputs=out, input_len=audio_len)
+ out = res_block(inputs=out, input_len=audio_len)

out = self.activation(out)
# [B, 1, T_audio]
- out = self.post_conv(out, audio_len)
+ out = self.post_conv(inputs=out, input_len=audio_len)
audio = self.out_activation(out)
audio = rearrange(audio, "B 1 T -> B T")
return audio, audio_len
Expand Down Expand Up @@ -356,18 +360,19 @@ def output_types(self):
"fmap": [NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())],
}

+ @typecheck()
def forward(self, audio):
fmap = []

# [batch, 2, T_spec, fft]
out = self.stft(audio)
for conv in self.conv_layers:
# [batch, filters, T_spec, fft // 2**i]
- out = conv(out)
+ out = conv(inputs=out)
out = self.activation(out)
fmap.append(out)
# [batch, 1, T_spec, fft // 8]
- scores = self.conv_post(out)
+ scores = self.conv_post(inputs=out)
fmap.append(scores)
scores = rearrange(scores, "B 1 T C -> B C T")

Expand All @@ -382,7 +387,7 @@ def __init__(self, resolutions):
@property
def input_types(self):
return {
- "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
+ "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()),
}

Expand All @@ -395,6 +400,7 @@ def output_types(self):
"fmaps_gen": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]],
}

+ @typecheck()
def forward(self, audio_real, audio_gen):
scores_real = []
scores_gen = []
Expand Down Expand Up @@ -627,6 +633,7 @@ def output_types(self):
"indices": NeuralType(('B', 'T'), Index()),
}

+ @typecheck()
def forward(self, inputs, input_len):
input_flat = rearrange(inputs, "B T D -> (B T) D")
self._init_codes(input_flat)
Expand Down Expand Up @@ -746,6 +753,7 @@ def output_types(self):
"commit_loss": NeuralType((), LossType()),
}

+ @typecheck()
def forward(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor, float]:
commit_loss = 0.0
residual = rearrange(inputs, "B D T -> B T D")
Expand Down
6 changes: 3 additions & 3 deletions tests/collections/tts/modules/test_audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_conv1d(self):
lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32)

conv = Conv1dNorm(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size)
- out = conv(inputs, lengths)
+ out = conv(inputs=inputs, input_len=lengths)

assert out.shape == (self.batch_size, self.out_channels, self.max_len)
assert torch.all(out[0, :, : self.len1] != 0.0)
Expand All @@ -66,7 +66,7 @@ def test_conv1d_downsample(self):
stride=stride,
padding=padding,
)
- out = conv(inputs, lengths)
+ out = conv(inputs=inputs, input_len=lengths)

assert out.shape == (self.batch_size, self.out_channels, out_len)
assert torch.all(out[0, :, :out_len_1] != 0.0)
Expand All @@ -87,7 +87,7 @@ def test_conv1d_transpose_upsample(self):
conv = ConvTranspose1dNorm(
in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=stride
)
- out = conv(inputs, lengths)
+ out = conv(inputs=inputs, input_len=lengths)

assert out.shape == (self.batch_size, self.out_channels, out_len)
assert torch.all(out[0, :, :out_len_1] != 0.0)
Expand Down
Loading