mirror of https://github.com/coqui-ai/TTS.git
bug fix
This commit is contained in:
parent
754e0d3b63
commit
d2657cbf3a
|
@ -128,7 +128,7 @@ class LJSpeechDataset(Dataset):
|
||||||
linear = torch.FloatTensor(linear)
|
linear = torch.FloatTensor(linear)
|
||||||
mel = torch.FloatTensor(mel)
|
mel = torch.FloatTensor(mel)
|
||||||
mel_lengths = torch.LongTensor(mel_lengths)
|
mel_lengths = torch.LongTensor(mel_lengths)
|
||||||
stop_targets = torch.FloatTensor(stop_targets).squeeze()
|
stop_targets = torch.FloatTensor(stop_targets)
|
||||||
|
|
||||||
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
|
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ class Tacotron(nn.Module):
|
||||||
# Reshape
|
# Reshape
|
||||||
# batch x time x dim
|
# batch x time x dim
|
||||||
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
||||||
stop_tokens = self.stopnet(mel_outputs)
|
stop_tokens = self.stopnet(mel_outputs).squeeze()
|
||||||
linear_outputs = self.postnet(mel_outputs)
|
linear_outputs = self.postnet(mel_outputs)
|
||||||
linear_outputs = self.last_linear(linear_outputs)
|
linear_outputs = self.last_linear(linear_outputs)
|
||||||
return mel_outputs, linear_outputs, alignments, stop_tokens
|
return mel_outputs, linear_outputs, alignments, stop_tokens
|
||||||
|
|
Loading…
Reference in New Issue