pass num_chars in train.py

This commit is contained in:
Eren Golge 2019-01-21 14:52:40 +01:00
parent 328db7757d
commit b011dafbab
2 changed files with 6 additions and 4 deletions

View File

@ -2,12 +2,12 @@
import torch import torch
from torch import nn from torch import nn
from math import sqrt from math import sqrt
from utils.text.symbols import symbols, phonemes
from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
class Tacotron(nn.Module): class Tacotron(nn.Module):
def __init__(self, def __init__(self,
num_chars,
embedding_dim=256, embedding_dim=256,
linear_dim=1025, linear_dim=1025,
mel_dim=80, mel_dim=80,
@ -19,8 +19,8 @@ class Tacotron(nn.Module):
self.mel_dim = mel_dim self.mel_dim = mel_dim
self.linear_dim = linear_dim self.linear_dim = linear_dim
self.embedding = nn.Embedding( self.embedding = nn.Embedding(
len(phonemes), embedding_dim, padding_idx=padding_idx) num_chars, embedding_dim, padding_idx=padding_idx)
print(" | > Number of characters : {}".format(len(phonemes))) print(" | > Number of characters : {}".format(num_chars))
self.embedding.weight.data.normal_(0, 0.3) self.embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(embedding_dim) self.encoder = Encoder(embedding_dim)
self.decoder = Decoder(256, mel_dim, r, attn_windowing) self.decoder = Decoder(256, mel_dim, r, attn_windowing)

View File

@ -17,6 +17,7 @@ from utils.generic_utils import (
remove_experiment_folder, create_experiment_folder, save_checkpoint, remove_experiment_folder, create_experiment_folder, save_checkpoint,
save_best_model, load_config, lr_decay, count_parameters, check_update, save_best_model, load_config, lr_decay, count_parameters, check_update,
get_commit_hash, sequence_mask, NoamLR) get_commit_hash, sequence_mask, NoamLR)
from utils.text.symbols import symbols, phonemes
from utils.visual import plot_alignment, plot_spectrogram from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron from models.tacotron import Tacotron
from layers.losses import L1LossMasked from layers.losses import L1LossMasked
@ -355,7 +356,8 @@ def evaluate(model, criterion, criterion_st, ap, current_step):
def main(args): def main(args):
model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r) num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = Tacotron(num_chars, c.embedding_size, ap.num_freq, ap.num_mels, c.r)
print(" | > Num output units : {}".format(ap.num_freq), flush=True) print(" | > Num output units : {}".format(ap.num_freq), flush=True)
optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0) optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)