Weight positive values for stopnet loss; rename weight_decay to adam_weight_decay

This commit is contained in:
Eren Golge 2019-09-28 15:44:17 +02:00
parent 99d7f2a666
commit acbafb456b
2 changed files with 7 additions and 7 deletions

View File

@ -18,7 +18,7 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
create_experiment_folder, get_git_branch,
load_config, remove_experiment_folder,
save_best_model, save_checkpoint, weight_decay,
save_best_model, save_checkpoint, adam_weight_decay,
set_init_dict, copy_config_file, setup_model,
split_dataset, gradual_training_scheduler, KeepAverage,
set_weight_decay)
@ -187,7 +187,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
loss += stop_loss
loss.backward()
optimizer, current_lr = weight_decay(optimizer)
optimizer, current_lr = adam_weight_decay(optimizer)
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
@ -198,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
# backpass and check the grad norm for stop loss
if c.separate_stopnet:
stop_loss.backward()
optimizer_st, _ = weight_decay(optimizer_st)
optimizer_st, _ = adam_weight_decay(optimizer_st)
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
optimizer_st.step()
else:
@ -526,7 +526,7 @@ def main(args): # pylint: disable=redefined-outer-name
else:
criterion = nn.L1Loss() if c.model in [
"Tacotron", "TacotronGST"] else nn.MSELoss()
criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None
criterion_st = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(20.0)) if c.stopnet else None
if args.restore_path:
checkpoint = torch.load(args.restore_path)

View File

@ -169,7 +169,7 @@ def lr_decay(init_lr, global_step, warmup_steps):
return lr
def weight_decay(optimizer):
def adam_weight_decay(optimizer):
"""
Custom weight decay operation, not affecting grad values.
"""
@ -181,7 +181,7 @@ def weight_decay(optimizer):
param.data)
return optimizer, current_lr
# pylint: disable=dangerous-default-value
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
"""
Skip biases, BatchNorm parameters, rnns.
@ -316,7 +316,7 @@ def split_dataset(items):
is_multi_speaker = False
speakers = [item[-1] for item in items]
is_multi_speaker = len(set(speakers)) > 1
eval_split_size = 500 if 500 < len(items) * 0.01 else int(
eval_split_size = 500 if len(items) * 0.01 > 500 else int(
len(items) * 0.01)
np.random.seed(0)
np.random.shuffle(items)