Mirror of https://github.com/coqui-ai/TTS.git

Commit 56f8b2d19f: Harmonized teacher-forcing
Parent: b5f2181e04
@@ -12,7 +12,7 @@
     "text_cleaner": "english_cleaners",
     "epochs": 2000,
-    "lr": 0.003,
+    "lr": 0.001,
     "batch_size": 180,
     "r": 5,

@@ -307,9 +307,13 @@ class Decoder(nn.Module):
             else:
                 # combine prev. model output and prev. real target
                 memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
+                memory_input = torch.nn.functional.dropout(memory_input,
+                                                           0.1,
+                                                           training=True)
                 # add a random noise
-                memory_input += torch.autograd.Variable(
-                    torch.randn(memory_input.size())).type_as(memory_input)
+                noise = torch.autograd.Variable(
+                    memory_input.data.new(ins.size()).normal_(0.0, 1.0))
+                memory_input = memory_input + noise

             # Prenet
             processed_memory = self.prenet(memory_input)
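The hunk above is the core of the harmonized teacher forcing: the decoder input for step t is the average of the previously predicted frame and the previous ground-truth frame, pushed through dropout that is kept active (training=True), with zero-mean Gaussian noise added on top. A minimal sketch of that computation in current PyTorch; the function name is illustrative, and randn_like replaces the older Variable/.data.new pattern used in the diff:

import torch
import torch.nn.functional as F

def harmonized_memory_input(prev_output, prev_target, dropout_p=0.1, noise_std=1.0):
    # Average the previous model output and the previous real target frame.
    memory_input = (prev_output + prev_target) / 2.0
    # Dropout stays active regardless of train/eval mode, as in the diff (training=True).
    memory_input = F.dropout(memory_input, p=dropout_p, training=True)
    # Add zero-mean Gaussian noise; the diff draws it with .normal_(0.0, 1.0).
    noise = torch.randn_like(memory_input) * noise_std
    return memory_input + noise

Compared to plain teacher forcing (feeding memory[t-1] directly), the decoder is trained on inputs that already contain its own, noisier predictions, presumably to reduce the mismatch between training and free-running inference.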
@@ -360,5 +364,5 @@ class Decoder(nn.Module):
         return outputs, alignments


-def is_end_of_frames(output, eps=0.1): #0.2
+def is_end_of_frames(output, eps=0.2): #0.2
     return (output.data <= eps).all()

train.py (23 changed lines)
@@ -90,9 +90,6 @@ def main(args):
     # onnx.export(model, dummy_input, model_proto_path, verbose=True)
     # tb.add_graph_onnx(model_proto_path)

-    if use_cuda:
-        model = nn.DataParallel(model.cuda())
-
     optimizer = optim.Adam(model.parameters(), lr=c.lr)

     if args.restore_step:
@@ -103,10 +100,20 @@ def main(args):
         print("\n > Model restored from step %d\n" % args.restore_step)
         start_epoch = checkpoint['step'] // len(dataloader)
         best_loss = checkpoint['linear_loss']
-    else:
+    elif args.restore_path:
+        checkpoint = torch.load(args.restore_path)
+        model.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        print("\n > Model restored from step %d\n" % checkpoint['step'])
+        start_epoch = checkpoint['step'] // len(dataloader)
+        best_loss = checkpoint['linear_loss']
+        start_epoch = 0
+    else:
         print("\n > Starting a new training")

+    if use_cuda:
+        model = nn.DataParallel(model.cuda())

     num_params = count_parameters(model)
     print(" | > Model has {} parameters".format(num_params))

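Note the reordering across the two hunks above: the nn.DataParallel wrapping moves from before the checkpoint-restore logic to after it, and a separate args.restore_path branch is added. Loading the state dict into the bare model before wrapping it avoids any mismatch with "module."-prefixed parameter names in the checkpoint; that rationale is an inference, not stated in the commit. A small sketch of the resulting order (helper name and map_location are illustrative):

import torch
import torch.nn as nn

def build_and_maybe_restore(model, optimizer, restore_path=None, use_cuda=True):
    # Restore weights into the unwrapped model first ...
    if restore_path:
        checkpoint = torch.load(restore_path, map_location="cpu")
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print(" > Model restored from step %d" % checkpoint['step'])
    # ... and only wrap it in DataParallel afterwards.
    if use_cuda:
        model = nn.DataParallel(model.cuda())
    return model, optimizer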
@@ -142,9 +149,9 @@ def main(args):
             current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1

             # setup lr
-            current_lr = lr_decay(c.lr, current_step)
-            for params_group in optimizer.param_groups:
-                params_group['lr'] = current_lr
+            # current_lr = lr_decay(c.lr, current_step)
+            # for params_group in optimizer.param_groups:
+            #     params_group['lr'] = current_lr

             optimizer.zero_grad()

@@ -192,7 +199,7 @@ def main(args):
             # loss = loss.cuda()

             loss.backward()
-            grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.) ## TODO: maybe no need
+            grad_norm = nn.utils.clip_grad_norm(model.parameters(), 0.5) ## TODO: maybe no need
             optimizer.step()

             step_time = time.time() - start_time
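The last two train.py hunks pin the learning rate (the per-step lr_decay schedule is commented out, and the config lr drops from 0.003 to 0.001) and tighten gradient clipping from a max norm of 1.0 to 0.5. A condensed sketch of the resulting loop body; clip_grad_norm_ is the current in-place spelling of the older clip_grad_norm used in the diff, and the single criterion stands in for the model's actual losses:

import torch.nn as nn

def training_step(model, optimizer, criterion, inputs, targets):
    # Learning rate stays at the fixed Adam lr from the config (0.001 after this commit);
    # the per-step lr_decay schedule is commented out in the diff.
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # Clip the global gradient norm to 0.5 (down from 1.0) before the update.
    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    return loss.item(), grad_norm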
@@ -7,6 +7,7 @@ import datetime
 import json
+import torch
 import numpy as np
 from collections import OrderedDict


 class AttrDict(dict):