mirror of https://github.com/coqui-ai/TTS.git
pass num_chars in train.py
parent 328db7757d
commit b011dafbab
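Summary: Tacotron previously sized its character embedding with `len(phonemes)`, imported directly from `utils.text.symbols`. This commit threads an explicit `num_chars` argument through the model constructor and computes it in `train.py`, selecting the phoneme or character table according to `c.use_phonemes`.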
models/tacotron.py

```diff
@@ -2,12 +2,12 @@
 import torch
 from torch import nn
 from math import sqrt
-from utils.text.symbols import symbols, phonemes
 from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG


 class Tacotron(nn.Module):
     def __init__(self,
+                 num_chars,
                  embedding_dim=256,
                  linear_dim=1025,
                  mel_dim=80,
@@ -19,8 +19,8 @@ class Tacotron(nn.Module):
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(
-            len(phonemes), embedding_dim, padding_idx=padding_idx)
-        print(" | > Number of characters : {}".format(len(phonemes)))
+            num_chars, embedding_dim, padding_idx=padding_idx)
+        print(" | > Number of characters : {}".format(num_chars))
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(embedding_dim)
         self.decoder = Decoder(256, mel_dim, r, attn_windowing)
```
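The effect of the model-side change is easiest to see in isolation: the embedding table is sized by a caller-supplied argument instead of the imported phoneme list. A minimal runnable sketch, with `TacotronSketch` as a hypothetical stand-in for the real model (encoder/decoder omitted):

```python
import torch
from torch import nn

class TacotronSketch(nn.Module):
    # Hypothetical stand-in: only the vocabulary-dependent part of the model.
    def __init__(self, num_chars, embedding_dim=256, padding_idx=None):
        super().__init__()
        # The table size now comes from the caller, so the same class works
        # for either a character or a phoneme vocabulary.
        self.embedding = nn.Embedding(
            num_chars, embedding_dim, padding_idx=padding_idx)
        self.embedding.weight.data.normal_(0, 0.3)

    def forward(self, characters):
        # characters: LongTensor of shape (batch, time) holding symbol ids.
        return self.embedding(characters)

model = TacotronSketch(num_chars=61)
out = model(torch.zeros(2, 10, dtype=torch.long))
print(out.shape)  # torch.Size([2, 10, 256])
```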
train.py (4 changed lines)

```diff
@@ -17,6 +17,7 @@ from utils.generic_utils import (
     remove_experiment_folder, create_experiment_folder, save_checkpoint,
     save_best_model, load_config, lr_decay, count_parameters, check_update,
     get_commit_hash, sequence_mask, NoamLR)
+from utils.text.symbols import symbols, phonemes
 from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
@@ -355,7 +356,8 @@ def evaluate(model, criterion, criterion_st, ap, current_step):


 def main(args):
-    model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
+    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+    model = Tacotron(num_chars, c.embedding_size, ap.num_freq, ap.num_mels, c.r)
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)

     optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
```
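The call-site logic in `main()` can be illustrated self-contained; the toy symbol lists below are placeholders, not the repo's real tables in `utils.text.symbols`:

```python
# Toy vocabularies standing in for utils.text.symbols.symbols / .phonemes.
symbols = ["_", "a", "b", "c"]
phonemes = ["_", "AA", "AE", "AH", "AO"]

class Config:
    # Mirrors the c.use_phonemes flag read by train.py.
    use_phonemes = True

c = Config()
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
print(num_chars)  # 5 -> passed as Tacotron's first constructor argument
```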