mirror of https://github.com/coqui-ai/TTS.git
Make loss masking optional
This commit is contained in:
parent
8a47b46195
commit
e2cf35bb10
35
train.py
35
train.py
|
@ -125,11 +125,18 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
if c.loss_masking:
|
||||||
if c.model == "Tacotron":
|
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||||
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
if c.model == "Tacotron":
|
||||||
|
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
||||||
|
else:
|
||||||
|
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
||||||
else:
|
else:
|
||||||
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
decoder_loss = criterion(decoder_output, mel_input)
|
||||||
|
if c.model == "Tacotron":
|
||||||
|
postnet_loss = criterion(postnet_output, linear_input)
|
||||||
|
else:
|
||||||
|
postnet_loss = criterion(postnet_output, mel_input)
|
||||||
loss = decoder_loss + postnet_loss
|
loss = decoder_loss + postnet_loss
|
||||||
|
|
||||||
# backpass and check the grad norm for spec losses
|
# backpass and check the grad norm for spec losses
|
||||||
|
@ -283,11 +290,18 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
if c.loss_masking:
|
||||||
if c.model == "Tacotron":
|
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||||
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
if c.model == "Tacotron":
|
||||||
|
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
||||||
|
else:
|
||||||
|
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
||||||
else:
|
else:
|
||||||
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
decoder_loss = criterion(decoder_output, mel_input)
|
||||||
|
if c.model == "Tacotron":
|
||||||
|
postnet_loss = criterion(postnet_output, linear_input)
|
||||||
|
else:
|
||||||
|
postnet_loss = criterion(postnet_output, mel_input)
|
||||||
loss = decoder_loss + postnet_loss + stop_loss
|
loss = decoder_loss + postnet_loss + stop_loss
|
||||||
|
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
|
@ -383,7 +397,10 @@ def main(args):
|
||||||
optimizer_st = optim.Adam(
|
optimizer_st = optim.Adam(
|
||||||
model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
|
model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
|
||||||
|
|
||||||
criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
|
if c.loss_masking:
|
||||||
|
criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
|
||||||
|
else:
|
||||||
|
criterion = nn.L1Loss() if c.model == "Tacotron" else nn.MSELoss()
|
||||||
criterion_st = nn.BCEWithLogitsLoss()
|
criterion_st = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
|
|
|
@ -56,7 +56,7 @@ def create_experiment_folder(root_path, model_name, debug):
|
||||||
""" Create a folder with the current date and time """
|
""" Create a folder with the current date and time """
|
||||||
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
||||||
# if debug:
|
# if debug:
|
||||||
# commit_hash = 'debug'
|
# commit_hash = 'debug'
|
||||||
# else:
|
# else:
|
||||||
commit_hash = get_commit_hash()
|
commit_hash = get_commit_hash()
|
||||||
output_folder = os.path.join(
|
output_folder = os.path.join(
|
||||||
|
@ -207,8 +207,8 @@ def sequence_mask(sequence_length, max_len=None):
|
||||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||||
if sequence_length.is_cuda:
|
if sequence_length.is_cuda:
|
||||||
seq_range_expand = seq_range_expand.cuda()
|
seq_range_expand = seq_range_expand.cuda()
|
||||||
seq_length_expand = (sequence_length.unsqueeze(1)
|
seq_length_expand = (
|
||||||
.expand_as(seq_range_expand))
|
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
|
||||||
# B x T_max
|
# B x T_max
|
||||||
return seq_range_expand < seq_length_expand
|
return seq_range_expand < seq_length_expand
|
||||||
|
|
||||||
|
@ -239,16 +239,27 @@ def set_init_dict(model_dict, checkpoint, c):
|
||||||
}
|
}
|
||||||
# 4. overwrite entries in the existing state dict
|
# 4. overwrite entries in the existing state dict
|
||||||
model_dict.update(pretrained_dict)
|
model_dict.update(pretrained_dict)
|
||||||
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
|
print(" | > {} / {} layers are restored.".format(
|
||||||
|
len(pretrained_dict), len(model_dict)))
|
||||||
return model_dict
|
return model_dict
|
||||||
|
|
||||||
|
|
||||||
def setup_model(num_chars, c):
    """Instantiate the TTS model selected in the config.

    Args:
        num_chars: size of the input character vocabulary (rows of the
            model's character embedding).
        c: config object; reads ``c.model`` plus model-specific fields
           (``r``, ``attention_norm``, ``memory_size`` for Tacotron;
           ``prenet_type``, ``use_forward_attn``, ``transition_agent``
           for Tacotron2).

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``c.model`` is not a recognized model name.
    """
    print(" > Using model: {}".format(c.model))
    # Each model lives in models/<name>.py and exposes a class named
    # exactly like `c.model` (e.g. models/tacotron.py -> Tacotron).
    MyModel = importlib.import_module('models.' + c.model.lower())
    MyModel = getattr(MyModel, c.model)
    if c.model.lower() == "tacotron":
        model = MyModel(
            num_chars=num_chars,
            r=c.r,
            attn_norm=c.attention_norm,
            memory_size=c.memory_size)
    elif c.model.lower() == "tacotron2":
        model = MyModel(
            num_chars=num_chars,
            r=c.r,
            attn_norm=c.attention_norm,
            prenet_type=c.prenet_type,
            forward_attn=c.use_forward_attn,
            trans_agent=c.transition_agent)
    else:
        # Previously `model` was left unbound here, so an unknown name
        # surfaced as a confusing UnboundLocalError; fail loudly instead.
        raise ValueError(" [!] Unknown model name: {}".format(c.model))
    return model
|
Loading…
Reference in New Issue