unit test fix

This commit is contained in:
Aya Jafari 2023-10-13 10:56:47 -03:00
parent 6eaecab0ca
commit ffddf10458
1 changed files with 2 additions and 3 deletions

View File

@ -395,8 +395,8 @@ class ForwardTTS(BaseTTS):
- x_mask: :math:`(B, 1, T_{en})` - x_mask: :math:`(B, 1, T_{en})`
- g: :math:`(B, C)` - g: :math:`(B, C)`
""" """
g = g.type(torch.LongTensor)
if hasattr(self, "emb_g"): if hasattr(self, "emb_g"):
g = g.type(torch.LongTensor)
g = self.emb_g(g) # [B, C, 1] g = self.emb_g(g) # [B, C, 1]
if g is not None: if g is not None:
g = g.unsqueeze(-1) g = g.unsqueeze(-1)
@ -684,8 +684,7 @@ class ForwardTTS(BaseTTS):
# encoder pass # encoder pass
o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g) o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
# duration predictor pass # duration predictor pass
o_en = o_en.squeeze() o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask)
o_dr_log = self.duration_predictor(o_en, x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1) y_lengths = o_dr.sum(1)