stop token prediction update for tacotron model

This commit is contained in:
Eren Golge 2018-05-11 04:15:06 -07:00
parent c12e4245b2
commit 1ddff5fc17
1 changed files with 3 additions and 3 deletions

View File

@ -14,7 +14,7 @@ class Tacotron(nn.Module):
self.linear_dim = linear_dim self.linear_dim = linear_dim
self.embedding = nn.Embedding(len(symbols), embedding_dim, self.embedding = nn.Embedding(len(symbols), embedding_dim,
padding_idx=padding_idx) padding_idx=padding_idx)
print(" | > Number of characted : {}".format(len(symbols))) print(" | > Number of characters : {}".format(len(symbols)))
self.embedding.weight.data.normal_(0, 0.3) self.embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(embedding_dim) self.encoder = Encoder(embedding_dim)
self.decoder = Decoder(256, mel_dim, r) self.decoder = Decoder(256, mel_dim, r)
@ -27,11 +27,11 @@ class Tacotron(nn.Module):
# batch x time x dim # batch x time x dim
encoder_outputs = self.encoder(inputs) encoder_outputs = self.encoder(inputs)
# batch x time x dim*r # batch x time x dim*r
mel_outputs, alignments = self.decoder( mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs) encoder_outputs, mel_specs)
# Reshape # Reshape
# batch x time x dim # batch x time x dim
mel_outputs = mel_outputs.view(B, -1, self.mel_dim) mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
linear_outputs = self.postnet(mel_outputs) linear_outputs = self.postnet(mel_outputs)
linear_outputs = self.last_linear(linear_outputs) linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments return mel_outputs, linear_outputs, alignments, stop_tokens