reset the way ga_loss is stored in return_dict

This commit is contained in:
erogol 2020-11-02 13:18:56 +01:00
parent a108d0ee81
commit d94782a076
1 changed files with 1 additions and 1 deletions

View File

@ -346,7 +346,7 @@ class TacotronLoss(torch.nn.Module):
if self.config.ga_alpha > 0:
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
loss += ga_loss * self.ga_alpha
return_dict['ga_loss'] = ga_loss * self.ga_alpha
return_dict['ga_loss'] = ga_loss
# decoder differential spectral loss
if self.config.decoder_diff_spec_alpha > 0: