mirror of https://github.com/coqui-ai/TTS.git
Make loss masking optional
This commit is contained in:
parent
8a47b46195
commit
e2cf35bb10
35
train.py
35
train.py
|
@ -125,11 +125,18 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
|
||||
# loss computation
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||
if c.model == "Tacotron":
|
||||
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
||||
if c.loss_masking:
|
||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||
if c.model == "Tacotron":
|
||||
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
||||
decoder_loss = criterion(decoder_output, mel_input)
|
||||
if c.model == "Tacotron":
|
||||
postnet_loss = criterion(postnet_output, linear_input)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input)
|
||||
loss = decoder_loss + postnet_loss
|
||||
|
||||
# backpass and check the grad norm for spec losses
|
||||
|
@ -283,11 +290,18 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
|||
|
||||
# loss computation
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||
if c.model == "Tacotron":
|
||||
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
||||
if c.loss_masking:
|
||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||
if c.model == "Tacotron":
|
||||
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
||||
decoder_loss = criterion(decoder_output, mel_input)
|
||||
if c.model == "Tacotron":
|
||||
postnet_loss = criterion(postnet_output, linear_input)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input)
|
||||
loss = decoder_loss + postnet_loss + stop_loss
|
||||
|
||||
step_time = time.time() - start_time
|
||||
|
@ -383,7 +397,10 @@ def main(args):
|
|||
optimizer_st = optim.Adam(
|
||||
model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
|
||||
|
||||
criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
|
||||
if c.loss_masking:
|
||||
criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
|
||||
else:
|
||||
criterion = nn.L1Loss() if c.model == "Tacotron" else nn.MSELoss()
|
||||
criterion_st = nn.BCEWithLogitsLoss()
|
||||
|
||||
if args.restore_path:
|
||||
|
|
|
@ -56,7 +56,7 @@ def create_experiment_folder(root_path, model_name, debug):
|
|||
""" Create a folder with the current date and time """
|
||||
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
||||
# if debug:
|
||||
# commit_hash = 'debug'
|
||||
# commit_hash = 'debug'
|
||||
# else:
|
||||
commit_hash = get_commit_hash()
|
||||
output_folder = os.path.join(
|
||||
|
@ -207,8 +207,8 @@ def sequence_mask(sequence_length, max_len=None):
|
|||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
if sequence_length.is_cuda:
|
||||
seq_range_expand = seq_range_expand.cuda()
|
||||
seq_length_expand = (sequence_length.unsqueeze(1)
|
||||
.expand_as(seq_range_expand))
|
||||
seq_length_expand = (
|
||||
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
|
||||
# B x T_max
|
||||
return seq_range_expand < seq_length_expand
|
||||
|
||||
|
@ -239,16 +239,27 @@ def set_init_dict(model_dict, checkpoint, c):
|
|||
}
|
||||
# 4. overwrite entries in the existing state dict
|
||||
model_dict.update(pretrained_dict)
|
||||
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
|
||||
print(" | > {} / {} layers are restored.".format(
|
||||
len(pretrained_dict), len(model_dict)))
|
||||
return model_dict
|
||||
|
||||
|
||||
def setup_model(num_chars, c):
|
||||
print(" > Using model: {}".format(c.model))
|
||||
MyModel = importlib.import_module('models.'+c.model.lower())
|
||||
MyModel = importlib.import_module('models.' + c.model.lower())
|
||||
MyModel = getattr(MyModel, c.model)
|
||||
if c.model.lower() == "tacotron":
|
||||
model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, memory_size=c.memory_size)
|
||||
model = MyModel(
|
||||
num_chars=num_chars,
|
||||
r=c.r,
|
||||
attn_norm=c.attention_norm,
|
||||
memory_size=c.memory_size)
|
||||
elif c.model.lower() == "tacotron2":
|
||||
model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, prenet_type=c.prenet_type, forward_attn=c.use_forward_attn)
|
||||
model = MyModel(
|
||||
num_chars=num_chars,
|
||||
r=c.r,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
forward_attn=c.use_forward_attn,
|
||||
trans_agent=c.transition_agent)
|
||||
return model
|
Loading…
Reference in New Issue