Make loss masking optional

Eren Golge 2019-04-10 16:41:08 +02:00
parent 8a47b46195
commit e2cf35bb10
2 changed files with 44 additions and 16 deletions

View File

@@ -125,11 +125,18 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
        # loss computation
        stop_loss = criterion_st(stop_tokens, stop_targets)
        if c.loss_masking:
            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
            if c.model == "Tacotron":
                postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
            else:
                postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
        else:
            decoder_loss = criterion(decoder_output, mel_input)
            if c.model == "Tacotron":
                postnet_loss = criterion(postnet_output, linear_input)
            else:
                postnet_loss = criterion(postnet_output, mel_input)
        loss = decoder_loss + postnet_loss
        # backpass and check the grad norm for spec losses
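When c.loss_masking is enabled, the criteria take an extra mel_lengths argument so that padded frames are excluded from the loss. A minimal sketch of such a length-masked L1 criterion, matching the three-argument call criterion(output, target, lengths) in the masked branch above (this is not the repo's L1LossMasked implementation; the class name is hypothetical):

import torch
import torch.nn as nn


class MaskedL1(nn.Module):  # illustrative sketch only
    def forward(self, output, target, lengths):
        # output, target: B x T x D; lengths: B (number of valid frames per item)
        max_len = target.size(1)
        # B x T boolean mask, True on real frames, False on padding
        mask = torch.arange(max_len, device=target.device)[None, :] < lengths[:, None]
        mask = mask.unsqueeze(-1).expand_as(target).float()
        # average the absolute error over valid positions only
        return (torch.abs(output - target) * mask).sum() / mask.sum()

# e.g. MaskedL1()(decoder_output, mel_input, mel_lengths) in place of the masked call above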
@@ -283,11 +290,18 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
        # loss computation
        stop_loss = criterion_st(stop_tokens, stop_targets)
        if c.loss_masking:
            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
            if c.model == "Tacotron":
                postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
            else:
                postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
        else:
            decoder_loss = criterion(decoder_output, mel_input)
            if c.model == "Tacotron":
                postnet_loss = criterion(postnet_output, linear_input)
            else:
                postnet_loss = criterion(postnet_output, mel_input)
        loss = decoder_loss + postnet_loss + stop_loss
        step_time = time.time() - start_time
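The evaluate() path mirrors train(). A small illustration (not from the repo) of why the toggle changes the reported loss values: with plain nn.L1Loss, zero-padded frames are averaged into the result, while a masked criterion averages over the real frames only:

import torch
import torch.nn as nn

target = torch.zeros(1, 4, 2)   # 1 item, 4 frames, 2 features; frames 2-3 are padding
output = torch.zeros(1, 4, 2)
output[:, :2] = 1.0             # prediction error only on the 2 real frames

print(nn.L1Loss()(output, target))                # tensor(0.5000) -- padding dilutes the mean
print(nn.L1Loss()(output[:, :2], target[:, :2]))  # tensor(1.) -- what a masked loss would report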
@@ -383,7 +397,10 @@ def main(args):
    optimizer_st = optim.Adam(
        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
    if c.loss_masking:
        criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
    else:
        criterion = nn.L1Loss() if c.model == "Tacotron" else nn.MSELoss()
    criterion_st = nn.BCEWithLogitsLoss()
    if args.restore_path:
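One note on the stop-token criterion selected here: nn.BCEWithLogitsLoss applies the sigmoid internally, so the stop_tokens passed to criterion_st above are expected to be raw logits. A quick check of the equivalence:

import torch
import torch.nn as nn

logits = torch.tensor([0.2, -1.5, 3.0])
targets = torch.tensor([0., 0., 1.])

a = nn.BCEWithLogitsLoss()(logits, targets)
b = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(a, b))  # True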

View File

@@ -207,8 +207,8 @@ def sequence_mask(sequence_length, max_len=None):
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
-    seq_length_expand = (sequence_length.unsqueeze(1)
-                         .expand_as(seq_range_expand))
+    seq_length_expand = (
+        sequence_length.unsqueeze(1).expand_as(seq_range_expand))
    # B x T_max
    return seq_range_expand < seq_length_expand
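For reference, the mask returned by sequence_mask marks the valid positions of each sequence; the same logic is repeated inline below so the example runs on its own:

import torch

sequence_length = torch.tensor([3, 1, 2])
max_len = int(sequence_length.max())
seq_range_expand = torch.arange(max_len).unsqueeze(0).expand(3, max_len)
seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
print(seq_range_expand < seq_length_expand)
# tensor([[ True,  True,  True],
#         [ True, False, False],
#         [ True,  True, False]])
# (older PyTorch versions print this as a uint8 tensor of 0/1)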
@@ -239,16 +239,27 @@ def set_init_dict(model_dict, checkpoint, c):
    }
    # 4. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
-    print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
+    print(" | > {} / {} layers are restored.".format(
+        len(pretrained_dict), len(model_dict)))
    return model_dict


def setup_model(num_chars, c):
    print(" > Using model: {}".format(c.model))
-    MyModel = importlib.import_module('models.'+c.model.lower())
+    MyModel = importlib.import_module('models.' + c.model.lower())
    MyModel = getattr(MyModel, c.model)
    if c.model.lower() == "tacotron":
-        model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, memory_size=c.memory_size)
+        model = MyModel(
+            num_chars=num_chars,
+            r=c.r,
+            attn_norm=c.attention_norm,
+            memory_size=c.memory_size)
    elif c.model.lower() == "tacotron2":
-        model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, prenet_type=c.prenet_type, forward_attn=c.use_forward_attn)
+        model = MyModel(
+            num_chars=num_chars,
+            r=c.r,
+            attn_norm=c.attention_norm,
+            prenet_type=c.prenet_type,
+            forward_attn=c.use_forward_attn,
+            trans_agent=c.transition_agent)
    return model
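setup_model resolves the model class by name at runtime: importlib.import_module loads the module under models/ and getattr pulls the class out of it. The same mechanism, demonstrated with a standard-library module since the repo's models package is not part of this diff:

import importlib

module = importlib.import_module('collections')  # cf. importlib.import_module('models.' + c.model.lower())
cls = getattr(module, 'OrderedDict')              # cf. getattr(MyModel, c.model)
print(cls)                                        # <class 'collections.OrderedDict'>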