mirror of https://github.com/coqui-ai/TTS.git
reset the way ga_loss is stored in return_dict
This commit is contained in:
parent
a108d0ee81
commit
d94782a076
|
@ -346,7 +346,7 @@ class TacotronLoss(torch.nn.Module):
|
||||||
if self.config.ga_alpha > 0:
|
if self.config.ga_alpha > 0:
|
||||||
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
|
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
|
||||||
loss += ga_loss * self.ga_alpha
|
loss += ga_loss * self.ga_alpha
|
||||||
return_dict['ga_loss'] = ga_loss * self.ga_alpha
|
return_dict['ga_loss'] = ga_loss
|
||||||
|
|
||||||
# decoder differential spectral loss
|
# decoder differential spectral loss
|
||||||
if self.config.decoder_diff_spec_alpha > 0:
|
if self.config.decoder_diff_spec_alpha > 0:
|
||||||
|
|
Loading…
Reference in New Issue