TTS Loss aggregated all loss functuons

This commit is contained in:
erogol 2020-03-09 21:03:18 +01:00
parent 57cbb3ab0c
commit d83b58e35d
1 changed files with 109 additions and 0 deletions

View File

@ -125,3 +125,112 @@ class BCELossMasked(nn.Module):
x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
loss = loss / mask.sum()
return loss
class GuidedAttentionLoss(torch.nn.Module):
def __init__(self, sigma=0.4):
super(GuidedAttentionLoss, self).__init__()
self.sigma = sigma
def _make_ga_masks(self, ilens, olens):
B = len(ilens)
max_ilen = max(ilens)
max_olen = max(olens)
ga_masks = torch.zeros((B, max_olen, max_ilen))
for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
ga_masks[idx, :olen, :ilen] = self._make_ga_mask(ilen, olen, self.sigma)
return ga_masks
def forward(self, att_ws, ilens, olens):
ga_masks = self._make_ga_masks(ilens, olens).to(att_ws.device)
seq_masks = self._make_masks(ilens, olens).to(att_ws.device)
losses = ga_masks * att_ws
loss = torch.mean(losses.masked_select(seq_masks))
return loss
@staticmethod
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2)))
@staticmethod
def _make_masks(ilens, olens):
in_masks = sequence_mask(ilens)
out_masks = sequence_mask(olens)
return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)
class TacotronLoss(torch.nn.Module):
def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4):
super(TacotronLoss, self).__init__()
self.stopnet_pos_weight = stopnet_pos_weight
self.ga_alpha = c.ga_alpha
self.config = c
# postnet decoder loss
if c.loss_masking:
self.criterion = L1LossMasked(c.seq_len_norm) if c.model in [
"Tacotron"
] else MSELossMasked(c.seq_len_norm)
else:
self.criterion = nn.L1Loss() if c.model in ["Tacotron"
] else nn.MSELoss()
# guided attention loss
if c.ga_alpha > 0:
self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma)
# stopnet loss
self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None
def forward(self, postnet_output, decoder_output, mel_input, linear_input,
stopnet_output, stopnet_target, output_lens, decoder_b_output,
alignments, alignment_lens, input_lens):
return_dict = {}
# decoder and postnet losses
if self.config.loss_masking:
decoder_loss = self.criterion(decoder_output, mel_input,
output_lens)
if self.config.model in ["Tacotron", "TacotronGST"]:
postnet_loss = self.criterion(postnet_output, linear_input,
output_lens)
else:
postnet_loss = self.criterion(postnet_output, mel_input,
output_lens)
else:
decoder_loss = self.criterion(decoder_output, mel_input)
if self.config.model in ["Tacotron", "TacotronGST"]:
postnet_loss = self.criterion(postnet_output, linear_input)
else:
postnet_loss = self.criterion(postnet_output, mel_input)
loss = decoder_loss + postnet_loss
return_dict['decoder_loss'] = decoder_loss
return_dict['postnet_loss'] = postnet_loss
# stopnet loss
stop_loss = self.criterion_st(
stopnet_output, stopnet_target,
output_lens) if self.config.stopnet else torch.zeros(1)
if not self.config.separate_stopnet and self.config.stopnet:
loss += stop_loss
return_dict['stopnet_loss'] = stop_loss
# backward decoder loss (if enabled)
if self.config.bidirectional_decoder:
if self.config.loss_masking:
decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input, output_lens)
else:
decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input)
decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1, )), decoder_output)
loss += decoder_b_loss + decoder_c_loss
return_dict['decoder_b_loss'] = decoder_b_loss
return_dict['decoder_c_loss'] = decoder_c_loss
# guided attention loss (if enabled)
if self.config.ga_alpha > 0:
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
loss += ga_loss * self.ga_alpha
return_dict['ga_loss'] = ga_loss * self.ga_alpha
return_dict['loss'] = loss
return return_dict