mirror of https://github.com/coqui-ai/TTS.git
Weight positive values in the stopnet loss (pos_weight=20.0); rename weight_decay to adam_weight_decay
This commit is contained in:
parent
99d7f2a666
commit
acbafb456b
8
train.py
8
train.py
|
@ -18,7 +18,7 @@ from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
|
from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
|
||||||
create_experiment_folder, get_git_branch,
|
create_experiment_folder, get_git_branch,
|
||||||
load_config, remove_experiment_folder,
|
load_config, remove_experiment_folder,
|
||||||
save_best_model, save_checkpoint, weight_decay,
|
save_best_model, save_checkpoint, adam_weight_decay,
|
||||||
set_init_dict, copy_config_file, setup_model,
|
set_init_dict, copy_config_file, setup_model,
|
||||||
split_dataset, gradual_training_scheduler, KeepAverage,
|
split_dataset, gradual_training_scheduler, KeepAverage,
|
||||||
set_weight_decay)
|
set_weight_decay)
|
||||||
|
@ -187,7 +187,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
loss += stop_loss
|
loss += stop_loss
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer, current_lr = weight_decay(optimizer)
|
optimizer, current_lr = adam_weight_decay(optimizer)
|
||||||
grad_norm, _ = check_update(model, c.grad_clip)
|
grad_norm, _ = check_update(model, c.grad_clip)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
@ -198,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
# backpass and check the grad norm for stop loss
|
# backpass and check the grad norm for stop loss
|
||||||
if c.separate_stopnet:
|
if c.separate_stopnet:
|
||||||
stop_loss.backward()
|
stop_loss.backward()
|
||||||
optimizer_st, _ = weight_decay(optimizer_st)
|
optimizer_st, _ = adam_weight_decay(optimizer_st)
|
||||||
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
|
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
|
||||||
optimizer_st.step()
|
optimizer_st.step()
|
||||||
else:
|
else:
|
||||||
|
@ -526,7 +526,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
else:
|
else:
|
||||||
criterion = nn.L1Loss() if c.model in [
|
criterion = nn.L1Loss() if c.model in [
|
||||||
"Tacotron", "TacotronGST"] else nn.MSELoss()
|
"Tacotron", "TacotronGST"] else nn.MSELoss()
|
||||||
criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None
|
criterion_st = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(20.0)) if c.stopnet else None
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
checkpoint = torch.load(args.restore_path)
|
checkpoint = torch.load(args.restore_path)
|
||||||
|
|
|
@ -169,7 +169,7 @@ def lr_decay(init_lr, global_step, warmup_steps):
|
||||||
return lr
|
return lr
|
||||||
|
|
||||||
|
|
||||||
def weight_decay(optimizer):
|
def adam_weight_decay(optimizer):
|
||||||
"""
|
"""
|
||||||
Custom weight decay operation, not affecting grad values.
|
Custom weight decay operation, not affecting grad values.
|
||||||
"""
|
"""
|
||||||
|
@ -181,7 +181,7 @@ def weight_decay(optimizer):
|
||||||
param.data)
|
param.data)
|
||||||
return optimizer, current_lr
|
return optimizer, current_lr
|
||||||
|
|
||||||
|
# pylint: disable=dangerous-default-value
|
||||||
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
|
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
|
||||||
"""
|
"""
|
||||||
Skip biases, BatchNorm parameters, rnns.
|
Skip biases, BatchNorm parameters, rnns.
|
||||||
|
@ -316,7 +316,7 @@ def split_dataset(items):
|
||||||
is_multi_speaker = False
|
is_multi_speaker = False
|
||||||
speakers = [item[-1] for item in items]
|
speakers = [item[-1] for item in items]
|
||||||
is_multi_speaker = len(set(speakers)) > 1
|
is_multi_speaker = len(set(speakers)) > 1
|
||||||
eval_split_size = 500 if 500 < len(items) * 0.01 else int(
|
eval_split_size = 500 if len(items) * 0.01 > 500 else int(
|
||||||
len(items) * 0.01)
|
len(items) * 0.01)
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
np.random.shuffle(items)
|
np.random.shuffle(items)
|
||||||
|
|
Loading…
Reference in New Issue