Commit

📝 Add debug sizes
rishikksh20 committed Jul 30, 2020
1 parent 1c5cb60 commit 0096e8a
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions fastspeech.py

@@ -177,9 +177,10 @@ def __init__(self, idim, odim):
 
     def _forward(self, xs, ilens, ys=None, olens=None, ds=None, es=None, ps=None, is_inference=False):
         # forward encoder
-        x_masks = self._source_mask(ilens)
-        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)
-        # print("Ys :",ys.shape) torch.Size([32, 868, 80])
+        x_masks = self._source_mask(ilens)  # (B, Tmax, Tmax) -> torch.Size([32, 121, 121])
+        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim) -> torch.Size([32, 121, 256])
+        # print("ys :", ys.shape)
+
 
         # forward duration predictor and length regulator
         d_masks = make_pad_mask(ilens).to(xs.device)

@@ -199,15 +200,15 @@ def _forward(self, xs, ilens, ys=None, olens=None, ds=None, es=None, ps=None, is_inference=False):
         one_hot_pitch = pitch_to_one_hot(ps)  # (B, Lmax, adim) torch.Size([32, 868, 256])
         # print("one_hot_pitch:", one_hot_pitch.shape)
         mel_masks = make_pad_mask(olens).to(xs.device)
-        # print("Before Hs:", hs.shape) torch.Size([32, 121, 256])
+        #print("Before Hs:", hs.shape) # torch.Size([32, 121, 256])
         d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
-        # print("d_outs:", d_outs.shape) torch.Size([32, 121])
+        #print("d_outs:", d_outs.shape) # torch.Size([32, 121])
         hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)
-        # print("After Hs:",hs.shape) torch.Size([32, 868, 256])
+        #print("After Hs:",hs.shape) #torch.Size([32, 868, 256])
         e_outs = self.energy_predictor(hs, mel_masks)
-        # print("e_outs:", e_outs.shape) torch.Size([32, 868])
+        #print("e_outs:", e_outs.shape) #torch.Size([32, 868])
         p_outs = self.pitch_predictor(hs, mel_masks)
-        # print("p_outs:", p_outs.shape) torch.Size([32, 868])
+        #print("p_outs:", p_outs.shape) #torch.Size([32, 868])
         hs = hs + self.pitch_embed(one_hot_pitch)  # (B, Lmax, adim)
         hs = hs + self.energy_embed(one_hot_energy)  # (B, Lmax, adim)
         # forward decoder

@@ -246,8 +247,8 @@ def forward(self, xs, ilens, ys, olens, ds, es, ps, *args, **kwargs):
         """
         # remove unnecessary padded part (for multi-gpus)
-        xs = xs[:, :max(ilens)]
-        ys = ys[:, :max(olens)]
+        xs = xs[:, :max(ilens)]  # torch.Size([32, 121]) -> [B, Tmax]
+        ys = ys[:, :max(olens)]  # torch.Size([32, 868, 80]) -> [B, Lmax, odim]
 
         # forward propagation
         before_outs, after_outs, d_outs, e_outs, p_outs = self._forward(xs, ilens, ys, olens, ds, es, ps, is_inference=False)
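
Note on the annotated shapes: the jump from torch.Size([32, 121, 256]) before the length regulator to torch.Size([32, 868, 256]) after it is the phoneme-to-frame expansion, where each encoder state is repeated according to its duration. The repository's LengthRegulator is called as self.length_regulator(hs, ds, ilens); the standalone helper below is only a minimal sketch of that expansion with made-up toy sizes, not the actual implementation.

import torch
import torch.nn.functional as F

def length_regulate(hs, ds):
    # hs: (B, Tmax, adim) encoder states, ds: (B, Tmax) integer durations.
    # Repeat each state hs[b, t] exactly ds[b, t] times along the time axis,
    # then pad every expanded sequence to the longest length Lmax in the batch.
    expanded = [torch.repeat_interleave(h, d, dim=0) for h, d in zip(hs, ds)]
    Lmax = max(e.size(0) for e in expanded)
    expanded = [F.pad(e, (0, 0, 0, Lmax - e.size(0))) for e in expanded]
    return torch.stack(expanded)  # (B, Lmax, adim)

B, Tmax, adim = 2, 4, 8
hs = torch.randn(B, Tmax, adim)
ds = torch.tensor([[1, 2, 3, 1], [2, 2, 2, 2]])  # per-phoneme frame counts
print(length_regulate(hs, ds).shape)  # torch.Size([2, 8, 8])

With the batch documented in the debug comments (B=32, Tmax=121, adim=256) and durations summing to 868 frames, this is the (32, 121, 256) -> (32, 868, 256) change recorded at the "After Hs" line, which is also why the energy and pitch predictors report torch.Size([32, 868]).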
