diff --git a/models/tacotron.py b/models/tacotron.py index 8a215b90..5eb7dfac 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -1,7 +1,8 @@ # coding: utf-8 import torch from torch import nn -from utils.text.symbols import symbols +from math import sqrt +from utils.text.symbols import symbols, phonemes from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG @@ -17,9 +18,11 @@ class Tacotron(nn.Module): self.mel_dim = mel_dim self.linear_dim = linear_dim self.embedding = nn.Embedding( - len(symbols), embedding_dim, padding_idx=padding_idx) - print(" | > Number of characters : {}".format(len(symbols))) - self.embedding.weight.data.normal_(0, 0.3) + len(phonemes), embedding_dim, padding_idx=padding_idx) + print(" | > Number of characters : {}".format(len(phonemes))) + std = sqrt(2.0 / (len(phonemes) + embedding_dim)) + val = sqrt(3.0) * std # uniform bounds for std + self.embedding.weight.data.uniform_(-val, val) self.encoder = Encoder(embedding_dim) self.decoder = Decoder(256, mel_dim, r) self.postnet = PostCBHG(mel_dim)