From acbafb456bdc460c5adeb7b1394d418f8f2f6758 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Sat, 28 Sep 2019 15:44:17 +0200
Subject: [PATCH] Weighting positive values for stopnet loss, change adam_weight_decay name

---
 train.py               | 8 ++++----
 utils/generic_utils.py | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/train.py b/train.py
index 7a68e2b0..cbcfb1ec 100644
--- a/train.py
+++ b/train.py
@@ -18,7 +18,7 @@ from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
                                      create_experiment_folder, get_git_branch,
                                      load_config, remove_experiment_folder,
-                                     save_best_model, save_checkpoint, weight_decay,
+                                     save_best_model, save_checkpoint, adam_weight_decay,
                                      set_init_dict, copy_config_file, setup_model,
                                      split_dataset, gradual_training_scheduler, KeepAverage,
                                      set_weight_decay)
@@ -187,7 +187,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             loss += stop_loss
 
         loss.backward()
-        optimizer, current_lr = weight_decay(optimizer)
+        optimizer, current_lr = adam_weight_decay(optimizer)
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
 
@@ -198,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         # backpass and check the grad norm for stop loss
         if c.separate_stopnet:
             stop_loss.backward()
-            optimizer_st, _ = weight_decay(optimizer_st)
+            optimizer_st, _ = adam_weight_decay(optimizer_st)
             grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
             optimizer_st.step()
         else:
@@ -526,7 +526,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         criterion = nn.L1Loss() if c.model in [
             "Tacotron", "TacotronGST"] else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None
+    criterion_st = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(20.0)) if c.stopnet else None
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 3188067f..50d611b8 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -169,7 +169,7 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr
 
 
-def weight_decay(optimizer):
+def adam_weight_decay(optimizer):
     """
     Custom weight decay operation, not effecting grad values.
     """
@@ -181,7 +181,7 @@ def weight_decay(optimizer):
                                         param.data)
     return optimizer, current_lr
-
+# pylint: disable=dangerous-default-value
 def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
     """
     Skip biases, BatchNorm parameters, rnns.
@@ -316,7 +316,7 @@ def split_dataset(items):
     is_multi_speaker = False
     speakers = [item[-1] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
-    eval_split_size = 500 if 500 < len(items) * 0.01 else int(
+    eval_split_size = 500 if len(items) * 0.01 > 500 else int(
         len(items) * 0.01)
     np.random.seed(0)
     np.random.shuffle(items)
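
Note (not part of the diff above): pos_weight in torch.nn.BCEWithLogitsLoss scales the loss term of positive targets, which compensates for the fact that only the last few decoder frames of an utterance carry a "stop" label. A minimal sketch of the effect follows; the logits and targets are made-up toy values rather than real stopnet outputs, and only the 20.0 factor mirrors the value hard-coded in the patch.

    import torch
    import torch.nn as nn

    # Toy stopnet output for 8 decoder steps; only the last two frames are "stop" frames.
    logits = torch.randn(8)
    targets = torch.tensor([0., 0., 0., 0., 0., 0., 1., 1.])

    plain = nn.BCEWithLogitsLoss()                                  # all frames weighted equally
    weighted = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(20.0))  # positive (stop) frames count 20x

    # The weighted criterion penalizes missed stop frames much more heavily,
    # nudging the model toward actually emitting a stop decision.
    print(plain(logits, targets).item(), weighted(logits, targets).item())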