mirror of https://github.com/coqui-ai/TTS.git
unit test fix
This commit is contained in:
parent
6eaecab0ca
commit
ffddf10458
|
@ -395,8 +395,8 @@ class ForwardTTS(BaseTTS):
|
||||||
- x_mask: :math:`(B, 1, T_{en})`
|
- x_mask: :math:`(B, 1, T_{en})`
|
||||||
- g: :math:`(B, C)`
|
- g: :math:`(B, C)`
|
||||||
"""
|
"""
|
||||||
g = g.type(torch.LongTensor)
|
|
||||||
if hasattr(self, "emb_g"):
|
if hasattr(self, "emb_g"):
|
||||||
|
g = g.type(torch.LongTensor)
|
||||||
g = self.emb_g(g) # [B, C, 1]
|
g = self.emb_g(g) # [B, C, 1]
|
||||||
if g is not None:
|
if g is not None:
|
||||||
g = g.unsqueeze(-1)
|
g = g.unsqueeze(-1)
|
||||||
|
@ -684,8 +684,7 @@ class ForwardTTS(BaseTTS):
|
||||||
# encoder pass
|
# encoder pass
|
||||||
o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
|
o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
|
||||||
# duration predictor pass
|
# duration predictor pass
|
||||||
o_en = o_en.squeeze()
|
o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask)
|
||||||
o_dr_log = self.duration_predictor(o_en, x_mask)
|
|
||||||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||||
y_lengths = o_dr.sum(1)
|
y_lengths = o_dr.sum(1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue