Skip to content

Commit

Permalink
[TTS] Fix audio codec type checks
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <[email protected]>
  • Loading branch information
rlangman committed Sep 8, 2023
1 parent 2f2e47d commit 53e12b7
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 46 deletions.
6 changes: 3 additions & 3 deletions nemo/collections/tts/losses/audio_codec_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def __init__(self, loss_fn, loss_scale: float = 1.0):
@property
def input_types(self):
return {
"target": NeuralType(('B', 'D', 'T'), RegressionValuesType()),
"predicted": NeuralType(('B', 'D', 'T'), PredictionsType()),
"target": NeuralType(('B', 'D', 'T'), RegressionValuesType()),
"target_len": NeuralType(tuple('B'), LengthsType()),
}

Expand Down Expand Up @@ -97,7 +97,7 @@ def input_types(self):
@property
def output_types(self):
return {
"loss": [NeuralType(elements_type=LossType())],
"loss": NeuralType(elements_type=LossType()),
}

@typecheck()
Expand Down Expand Up @@ -146,7 +146,7 @@ def input_types(self):
@property
def output_types(self):
return {
"loss": [NeuralType(elements_type=LossType())],
"loss": NeuralType(elements_type=LossType()),
}

@typecheck()
Expand Down
6 changes: 4 additions & 2 deletions nemo/collections/tts/models/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,11 @@ def configure_optimizers(self):
sched_config = optim_config.pop("sched", None)
OmegaConf.set_struct(optim_config, True)

gen_params = itertools.chain(self.audio_encoder.parameters(), self.audio_decoder.parameters())
disc_params = self.discriminator.parameters()
vq_params = self.vector_quantizer.parameters() if self.vector_quantizer else []
gen_params = itertools.chain(self.audio_encoder.parameters(), self.audio_decoder.parameters(), vq_params)
optim_g = instantiate(optim_config, params=gen_params)

disc_params = self.discriminator.parameters()
optim_d = instantiate(optim_config, params=disc_params)

if sched_config is None:
Expand Down
28 changes: 15 additions & 13 deletions nemo/collections/tts/modules/audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Optional, Tuple
from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange

from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
from nemo.core.classes.common import typecheck
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, LengthsType, VoidType
from nemo.core.neural_types.elements import LengthsType, VoidType
from nemo.core.neural_types.neural_type import NeuralType


Expand Down Expand Up @@ -64,21 +63,22 @@ def __init__(
def input_types(self):
return {
"inputs": NeuralType(('B', 'C', 'T'), VoidType()),
"lengths": NeuralType(tuple('B'), LengthsType()),
"input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
"out": [NeuralType(('B', 'C', 'T'), VoidType())],
"out": NeuralType(('B', 'C', 'T'), VoidType()),
}

def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv)

def forward(self, inputs, lengths):
@typecheck()
def forward(self, inputs, input_len):
out = self.conv(inputs)
out = mask_sequence_tensor(out, lengths)
out = mask_sequence_tensor(out, input_len)
return out


Expand All @@ -101,21 +101,22 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride
def input_types(self):
return {
"inputs": NeuralType(('B', 'C', 'T'), VoidType()),
"lengths": NeuralType(tuple('B'), LengthsType()),
"input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
"out": [NeuralType(('B', 'C', 'T'), VoidType())],
"out": NeuralType(('B', 'C', 'T'), VoidType()),
}

def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv)

def forward(self, inputs, lengths):
@typecheck()
def forward(self, inputs, input_len):
out = self.conv(inputs)
out = mask_sequence_tensor(out, lengths)
out = mask_sequence_tensor(out, input_len)
return out


Expand Down Expand Up @@ -151,11 +152,12 @@ def input_types(self):
@property
def output_types(self):
return {
"out": [NeuralType(('B', 'C', 'H', 'T'), VoidType())],
"out": NeuralType(('B', 'C', 'H', 'T'), VoidType()),
}

def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv)

@typecheck()
def forward(self, inputs):
return self.conv(inputs)
64 changes: 36 additions & 28 deletions nemo/collections/tts/modules/encodec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,28 +72,29 @@ def __init__(self, channels: int):
def input_types(self):
return {
"inputs": NeuralType(('B', 'C', 'T_input'), VoidType()),
"lengths": NeuralType(tuple('B'), LengthsType()),
"input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
"out": [NeuralType(('B', 'C', 'T_out'), VoidType())],
"out": NeuralType(('B', 'C', 'T_out'), VoidType()),
}

def remove_weight_norm(self):
self.pre_conv.remove_weight_norm()
self.res_conv1.remove_weight_norm()
self.res_conv2.remove_weight_norm()

def forward(self, inputs, lengths):
@typecheck()
def forward(self, inputs, input_len):
res = self.activation(inputs)
res = self.res_conv1(res, lengths)
res = self.res_conv1(inputs=res, input_len=input_len)
res = self.activation(res)
res = self.res_conv2(res, lengths)
res = self.res_conv2(inputs=res, input_len=input_len)

out = self.pre_conv(inputs, lengths) + res
out = mask_sequence_tensor(out, lengths)
out = self.pre_conv(inputs=inputs, input_len=input_len) + res
out = mask_sequence_tensor(out, input_len)
return out


Expand All @@ -112,20 +113,21 @@ def __init__(self, dim: int, num_layers: int, rnn_type: str = "lstm", use_skip:
def input_types(self):
return {
"inputs": NeuralType(('B', 'C', 'T'), VoidType()),
"lengths": NeuralType(tuple('B'), LengthsType()),
"input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
"out": [NeuralType(('B', 'C', 'T'), VoidType())],
"out": NeuralType(('B', 'C', 'T'), VoidType()),
}

def forward(self, inputs, lengths):
@typecheck()
def forward(self, inputs, input_len):
inputs = rearrange(inputs, "B C T -> B T C")

packed_inputs = nn.utils.rnn.pack_padded_sequence(
inputs, lengths=lengths.cpu(), batch_first=True, enforce_sorted=False
inputs, lengths=input_len.cpu(), batch_first=True, enforce_sorted=False
)
packed_out, _ = self.rnn(packed_inputs)
out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
Expand Down Expand Up @@ -183,15 +185,15 @@ def __init__(
@property
def input_types(self):
return {
"audio": NeuralType(('B', 'C', 'T_audio'), AudioSignal()),
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
"encoded": [NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation())],
"encoded_len": [NeuralType(tuple('B'), LengthsType())],
"encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
"encoded_len": NeuralType(tuple('B'), LengthsType()),
}

def remove_weight_norm(self):
Expand All @@ -201,26 +203,27 @@ def remove_weight_norm(self):
for down_sample_conv in self.down_sample_conv_layers:
down_sample_conv.remove_weight_norm()

@typecheck()
def forward(self, audio, audio_len):
encoded_len = audio_len
audio = rearrange(audio, "B T -> B 1 T")
# [B, C, T_audio]
out = self.pre_conv(audio, encoded_len)
out = self.pre_conv(inputs=audio, input_len=encoded_len)
for res_block, down_sample_conv, down_sample_rate in zip(
self.res_blocks, self.down_sample_conv_layers, self.down_sample_rates
):
# [B, C, T]
out = res_block(out, encoded_len)
out = res_block(inputs=out, input_len=encoded_len)
out = self.activation(out)

encoded_len = encoded_len // down_sample_rate
# [B, 2 * C, T / down_sample_rate]
out = down_sample_conv(out, encoded_len)
out = down_sample_conv(inputs=out, input_len=encoded_len)

out = self.rnn(out, encoded_len)
out = self.rnn(inputs=out, input_len=encoded_len)
out = self.activation(out)
# [B, encoded_dim, T_encoded]
encoded = self.post_conv(out, encoded_len)
encoded = self.post_conv(inputs=out, input_len=encoded_len)
return encoded, encoded_len


Expand Down Expand Up @@ -274,7 +277,7 @@ def input_types(self):
@property
def output_types(self):
return {
"audio": NeuralType(('B', 'C', 'T_audio'), AudioSignal()),
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
}

Expand All @@ -285,23 +288,24 @@ def remove_weight_norm(self):
for res_block in self.res_blocks:
res_block.remove_weight_norm()

@typecheck()
def forward(self, inputs, input_len):
audio_len = input_len
# [B, C, T_encoded]
out = self.pre_conv(inputs, audio_len)
out = self.rnn(out, audio_len)
out = self.pre_conv(inputs=inputs, input_len=audio_len)
out = self.rnn(inputs=out, input_len=audio_len)
for res_block, up_sample_conv, up_sample_rate in zip(
self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates
):
audio_len = audio_len * up_sample_rate
out = self.activation(out)
# [B, C / 2, T * up_sample_rate]
out = up_sample_conv(out, audio_len)
out = res_block(out, audio_len)
out = up_sample_conv(inputs=out, input_len=audio_len)
out = res_block(inputs=out, input_len=audio_len)

out = self.activation(out)
# [B, 1, T_audio]
out = self.post_conv(out, audio_len)
out = self.post_conv(inputs=out, input_len=audio_len)
audio = self.out_activation(out)
audio = rearrange(audio, "B 1 T -> B T")
return audio, audio_len
Expand Down Expand Up @@ -356,18 +360,19 @@ def output_types(self):
"fmap": [NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())],
}

@typecheck()
def forward(self, audio):
fmap = []

# [batch, 2, T_spec, fft]
out = self.stft(audio)
for conv in self.conv_layers:
# [batch, filters, T_spec, fft // 2**i]
out = conv(out)
out = conv(inputs=out)
out = self.activation(out)
fmap.append(out)
# [batch, 1, T_spec, fft // 8]
scores = self.conv_post(out)
scores = self.conv_post(inputs=out)
fmap.append(scores)
scores = rearrange(scores, "B 1 T C -> B C T")

Expand All @@ -382,7 +387,7 @@ def __init__(self, resolutions):
@property
def input_types(self):
return {
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_real": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()),
}

Expand All @@ -395,6 +400,7 @@ def output_types(self):
"fmaps_gen": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]],
}

@typecheck()
def forward(self, audio_real, audio_gen):
scores_real = []
scores_gen = []
Expand Down Expand Up @@ -627,6 +633,7 @@ def output_types(self):
"indices": NeuralType(('B', 'T'), Index()),
}

@typecheck()
def forward(self, inputs, input_len):
input_flat = rearrange(inputs, "B T D -> (B T) D")
self._init_codes(input_flat)
Expand Down Expand Up @@ -746,6 +753,7 @@ def output_types(self):
"commit_loss": NeuralType((), LossType()),
}

@typecheck()
def forward(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor, float]:
commit_loss = 0.0
residual = rearrange(inputs, "B D T -> B T D")
Expand Down

0 comments on commit 53e12b7

Please sign in to comment.