mirror of https://github.com/coqui-ai/TTS.git
add mk annealing (mk attn loss contribution)
This commit is contained in:
parent e7350346bf
commit c55900b4ad

config.json
@@ -1,33 +1,34 @@
 {
     "num_mels": 80,
     "num_freq": 1025,
     "sample_rate": 22050,
     "frame_length_ms": 50,
     "frame_shift_ms": 12.5,
     "preemphasis": 0.97,
     "min_level_db": -100,
     "ref_level_db": 20,
     "embedding_size": 256,
     "text_cleaner": "english_cleaners",
 
     "epochs": 500,
     "lr": 0.002,
     "warmup_steps": 4000,
     "batch_size": 32,
     "eval_batch_size":32,
     "r": 5,
+    "mk": 1,
 
     "griffin_lim_iters": 60,
     "power": 1.2,
 
     "dataset": "LJSpeech",
     "meta_file_train": "metadata_train.csv",
     "meta_file_val": "metadata_val.csv",
     "data_path": "/data/shared/KeithIto/LJSpeech-1.0/",
     "min_seq_len": 0,
     "num_loader_workers": 8,
 
     "checkpoint": true,
     "save_step": 908,
     "output_path": "/data/shared/erogol_models/"
 }
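
Note: "mk" is the new knob added by this commit; it is the initial weight of the guided-attention loss term that train.py anneals toward zero over training (see the train.py hunks below). A minimal sketch of reading it, using a plain json load rather than the repo's load_config helper:

    import json

    # load the training config; the file path is an assumption for this sketch
    with open("config.json") as f:
        c = json.load(f)

    print(c["mk"], c["epochs"])  # -> 1 500 with the values in this commit
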
train.py
@@ -19,7 +19,8 @@ from tensorboardX import SummaryWriter
 from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  save_best_model, load_config, lr_decay,
-                                 count_parameters, check_update, get_commit_hash)
+                                 count_parameters, check_update, get_commit_hash,
+                                 create_attn_mask)
 from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
@@ -90,6 +91,9 @@ def train(model, criterion, data_loader, optimizer, epoch):
             params_group['lr'] = current_lr
 
         optimizer.zero_grad()
 
+        # setup mk
+        mk = mk_decay(c.mk, c.epochs, epoch)
+
         # convert inputs to variables
         text_input_var = Variable(text_input)
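
Note: mk_decay (added in utils.generic_utils below) anneals the weight linearly over epochs: full weight at the first epoch, zero at the last. A minimal sketch of the schedule, assuming the config values "mk": 1 and "epochs": 500 from this commit:

    def mk_decay(init_mk, max_epoch, n_epoch):
        # linear annealing: init_mk at epoch 0, 0 at epoch max_epoch
        return init_mk * ((max_epoch - n_epoch) / max_epoch)

    for epoch in (0, 125, 250, 375, 500):
        print(epoch, mk_decay(1.0, 500, epoch))
    # -> 0 1.0, 125 0.75, 250 0.5, 375 0.25, 500 0.0
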
@@ -105,19 +109,10 @@ def train(model, criterion, data_loader, optimizer, epoch):
             linear_spec_var = linear_spec_var.cuda()
 
         # create attention mask
-        # TODO: vectorize
         N = text_input_var.shape[1]
         T = mel_spec_var.shape[1] // c.r
-        M = np.zeros([N, T])
-        for t in range(T):
-            for n in range(N):
-                val = 20 * np.exp(-pow((n/N)-(t/T), 2.0)/0.05)
-                M[n, t] = val
-        e_x = np.exp(M - np.max(M))
-        M = e_x / e_x.sum(axis=0)  # only difference
-        M = Variable(torch.FloatTensor(M).t()).cuda()
-        M = torch.stack([M]*32)
+        M = create_attn_mask(N, T, g)
 
         # forward pass
         mel_output, linear_output, alignments =\
             model.forward(text_input_var, mel_spec_var)
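
Note: this hunk replaces the inline double loop (previously marked "# TODO: vectorize") with a call to the new create_attn_mask helper. The g argument is presumably the Gaussian width of the soft-diagonal target (its definition at the call site is not shown in this hunk; the helper below defaults it to 0.05). A small sketch of what that width means for the pre-softmax target values, using the formula from the helper:

    import numpy as np

    g = 0.05  # assumed width, matching create_attn_mask's default
    on_diag = 20 * np.exp(-(0.0 ** 2) / g)   # n/N == t/T        -> 20.0
    off_diag = 20 * np.exp(-(0.5 ** 2) / g)  # far off-diagonal  -> ~0.13
    print(on_diag, off_diag)  # the target strongly favors a near-diagonal alignment
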
@@ -129,7 +124,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
                               linear_spec_var[:, :, :n_priority_freq],
                               mel_lengths_var)
         attention_loss = criterion(alignments, M, mel_lengths_var)
-        loss = mel_loss + linear_loss + attention_loss
+        loss = mel_loss + linear_loss + mk * attention_loss
 
         # backpass and check the grad norm
         loss.backward()
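
Note: the only change in this hunk is the weighting: the guided-attention term is now scaled by the annealed mk, so attention guidance is strongest early in training and contributes nothing by the final epoch. Schematically, with made-up scalar values standing in for the criterion outputs shown above:

    def total_loss(mel_loss, linear_loss, attention_loss, mk):
        # same composition as the new line in the hunk above
        return mel_loss + linear_loss + mk * attention_loss

    total_loss(0.8, 1.1, 0.4, mk=1.0)  # first epoch  -> 2.3 (guidance fully on)
    total_loss(0.8, 1.1, 0.4, mk=0.5)  # mid-training -> 2.1
    total_loss(0.8, 1.1, 0.4, mk=0.0)  # last epoch   -> 1.9 (guidance off)
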
utils/generic_utils.py
@@ -131,6 +131,24 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr
 
 
+def create_attn_mask(N, T, g=0.05):
+    r'''creating attn mask for guided attention'''
+    M = np.zeros([N, T])
+    for t in range(T):
+        for n in range(N):
+            val = 20 * np.exp(-pow((n/N)-(t/T), 2.0)/g)
+            M[n, t] = val
+    e_x = np.exp(M - np.max(M))
+    M = e_x / e_x.sum(axis=0)  # only difference
+    M = Variable(torch.FloatTensor(M).t()).cuda()
+    M = torch.stack([M]*32)
+    return M
+
+
+def mk_decay(init_mk, max_epoch, n_epoch):
+    return init_mk * ((max_epoch - n_epoch) / max_epoch)
+
+
 def count_parameters(model):
     r"""Count number of trainable parameters in a network"""
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
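
Note: create_attn_mask builds a soft near-diagonal target over (text position, decoder frame), normalizes it with a softmax over the text axis (for each decoder frame, the values over text positions sum to 1), transposes it to match the alignment layout, and stacks 32 copies to match the hard-coded batch size (32 in config.json). A vectorized sketch of the same computation, addressing the "TODO: vectorize" removed from train.py; this is an illustration, not the repo's code, and it drops the .cuda() call and exposes the batch size as a parameter:

    import numpy as np
    import torch

    def create_attn_mask_vectorized(N, T, g=0.05, batch_size=32):
        # same values as the python double loop above, computed with broadcasting
        n = np.arange(N, dtype=np.float32)[:, None] / N   # [N, 1] text positions
        t = np.arange(T, dtype=np.float32)[None, :] / T   # [1, T] decoder frames
        M = 20 * np.exp(-((n - t) ** 2) / g)              # [N, T] soft diagonal
        e_x = np.exp(M - np.max(M))
        M = e_x / e_x.sum(axis=0)                         # softmax over the text axis
        M = torch.FloatTensor(M).t()                      # [T, N], like the alignments
        return torch.stack([M] * batch_size)              # [batch_size, T, N]
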