mirror of https://github.com/coqui-ai/TTS.git
reset the way ga_loss is stored in return_dict
This commit is contained in:
parent
a108d0ee81
commit
d94782a076
|
@ -346,7 +346,7 @@ class TacotronLoss(torch.nn.Module):
|
|||
if self.config.ga_alpha > 0:
|
||||
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
|
||||
loss += ga_loss * self.ga_alpha
|
||||
return_dict['ga_loss'] = ga_loss * self.ga_alpha
|
||||
return_dict['ga_loss'] = ga_loss
|
||||
|
||||
# decoder differential spectral loss
|
||||
if self.config.decoder_diff_spec_alpha > 0:
|
||||
|
|
Loading…
Reference in New Issue