mirror of https://github.com/coqui-ai/TTS.git
new lr schedule
This commit is contained in:
parent 3519f2f33c
commit 7304978be3
@@ -16,9 +16,10 @@
     "lr_patience": 5,
     "lr_decay": 0.5,
     "batch_size": 256,
+    "r": 5,
 
     "griffin_lim_iters": 60,
     "power": 1.5,
-    "r": 5,
 
     "num_loader_workers": 32,
 
Binary file not shown.
Binary file not shown.
Binary file not shown.
models/tacotron.py
@@ -15,6 +15,8 @@ class Tacotron(nn.Module):
         self.use_memory_mask = use_memory_mask
         self.embedding = nn.Embedding(len(symbols), embedding_dim,
                                       padding_idx=padding_idx)
+        print(" | > Embedding dim : {}".format(len(symbols)))
+
         # Trying smaller std
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(embedding_dim)
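A small aside on the added log line: it prints len(symbols), i.e. the size of the symbol set, under the label "Embedding dim". If the intent is to report the embedding size itself, a hedged variant using the same constructor arguments shown in the hunk above (symbols and embedding_dim) could log both inside __init__:

    # sketch only; `symbols` and `embedding_dim` are the constructor inputs above
    print(" | > Number of symbols : {}".format(len(symbols)))
    print(" | > Embedding dim : {}".format(embedding_dim))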
train.py (26 lines changed)
@@ -19,7 +19,7 @@ from tensorboardX import SummaryWriter
 
 from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
-                                 load_config)
+                                 load_config, lr_decay)
 from utils.model import get_param_size
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
@@ -99,8 +99,9 @@ def main(args):
 
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
 
-    lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
-                                     patience=c.lr_patience, verbose=True)
+    #lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
+    #                                 patience=c.lr_patience, verbose=True)
 
     epoch_time = 0
     for epoch in range(c.epochs):
@@ -119,14 +120,19 @@ def main(args):
 
             current_step = i + args.restore_step + epoch * len(dataloader) + 1
 
+            # setup lr
+            current_lr = lr_decay(init_lr, current_step)
+            for params_group in optimizer.param_groups:
+                param_group['lr'] = current_lr
+
             optimizer.zero_grad()
 
-            try:
-                mel_input = np.concatenate((np.zeros(
-                    [c.batch_size, 1, c.num_mels], dtype=np.float32),
-                    mel_input[:, 1:, :]), axis=1)
-            except:
-                raise TypeError("not same dimension")
+            #try:
+            #    mel_input = np.concatenate((np.zeros(
+            #        [c.batch_size, 1, c.num_mels], dtype=np.float32),
+            #        mel_input[:, 1:, :]), axis=1)
+            #except:
+            #    raise TypeError("not same dimension")
 
             if use_cuda:
                 text_input_var = Variable(torch.from_numpy(text_input).type(
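The hunk above is the core of the new schedule: every training step recomputes the learning rate from lr_decay(init_lr, current_step) and writes it into each optimizer parameter group, replacing the plateau scheduler that is commented out earlier in this diff. Note that the committed loop iterates over params_group but assigns through param_group. A minimal sketch of the same per-step update with one consistent name follows; the helper name set_step_lr is ours, and it assumes lr_decay is importable from utils.generic_utils (added below in this commit) and that optimizer and init_lr are defined as in train.py.

from utils.generic_utils import lr_decay


def set_step_lr(optimizer, init_lr, current_step):
    # Compute the warmup/decay rate for this step and apply it to every
    # parameter group before the backward pass.
    current_lr = lr_decay(init_lr, current_step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr
    return current_lr


# Inside the training loop, before optimizer.zero_grad():
# current_lr = set_step_lr(optimizer, init_lr, current_step)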
@@ -204,7 +210,7 @@ def main(args):
             tb.add_image('Spec/Reconstruction', const_spec, current_step)
             tb.add_image('Spec/GroundTruth', gt_spec, current_step)
 
-        lr_scheduler.step(loss.data[0])
+        #lr_scheduler.step(loss.data[0])
         tb.add_scalar('Time/EpochTime', epoch_time, epoch)
         epoch_time = 0
utils/generic_utils.py
@@ -52,6 +52,13 @@ def save_checkpoint(state, filename='checkpoint.pth.tar'):
     torch.save(state, filename)
 
 
+def lr_decay(init_lr, global_step):
+    warmup_steps = 4000.0
+    step = global_step + 1.
+    lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
+                                                  step**-0.5)
+    return lr
+
 class Progbar(object):
     """Displays a progress bar.
     # Arguments
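The new lr_decay is a Noam-style warmup schedule: the rate grows roughly linearly for the first warmup_steps = 4000 steps, peaks near init_lr at the end of warmup, and then decays as 1/sqrt(step). A quick, self-contained sketch of its shape; init_lr = 0.001 is an arbitrary value chosen only for illustration, the real value comes from the training config.

import numpy as np


def lr_decay(init_lr, global_step):
    # Same formula as the helper added to utils.generic_utils above.
    warmup_steps = 4000.0
    step = global_step + 1.
    return init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                    step**-0.5)


if __name__ == "__main__":
    init_lr = 0.001  # illustrative only
    for step in (0, 500, 2000, 4000, 16000, 64000):
        print("step {:>6d}  lr {:.6f}".format(step, lr_decay(init_lr, step)))
    # The printed values rise towards init_lr until roughly step 4000, then
    # shrink as 1/sqrt(step); at step 16000 the rate is about half of init_lr.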