seq_len_norm for imbalanced datasets

Authored by root on 2020-01-15 12:30:07 +01:00; committed by erogol
parent 72817438db
commit 9921d682c3
2 changed files with 34 additions and 10 deletions


@@ -6,6 +6,11 @@ from TTS.utils.generic_utils import sequence_mask
 class L1LossMasked(nn.Module):
+    def __init__(self, seq_len_norm):
+        super(L1LossMasked, self).__init__()
+        self.seq_len_norm = seq_len_norm
+
     def forward(self, x, target, length):
         """
         Args:
@@ -24,14 +29,26 @@ class L1LossMasked(nn.Module):
         target.requires_grad = False
         mask = sequence_mask(
             sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
-        mask = mask.expand_as(x)
-        loss = functional.l1_loss(
-            x * mask, target * mask, reduction="sum")
-        loss = loss / mask.sum()
+        if self.seq_len_norm:
+            norm_w = mask / mask.sum(dim=1, keepdim=True)
+            out_weights = norm_w.div(target.shape[0] * target.shape[2])
+            mask = mask.expand_as(x)
+            loss = functional.l1_loss(
+                x * mask, target * mask, reduction='none')
+            loss = loss.mul(out_weights.cuda()).sum()
+        else:
+            mask = mask.expand_as(x)
+            loss = functional.l1_loss(
+                x * mask, target * mask, reduction='sum')
+            loss = loss / mask.sum()
         return loss


 class MSELossMasked(nn.Module):
+    def __init__(self, seq_len_norm):
+        super(MSELossMasked, self).__init__()
+        self.seq_len_norm = seq_len_norm
+
     def forward(self, x, target, length):
         """
         Args:
@@ -50,9 +67,16 @@ class MSELossMasked(nn.Module):
         target.requires_grad = False
         mask = sequence_mask(
             sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
-        mask = mask.expand_as(x)
-        loss = functional.mse_loss(
-            x * mask, target * mask, reduction="sum")
-        loss = loss / mask.sum()
+        if self.seq_len_norm:
+            norm_w = mask / mask.sum(dim=1, keepdim=True)
+            out_weights = norm_w.div(target.shape[0] * target.shape[2])
+            mask = mask.expand_as(x)
+            loss = functional.mse_loss(
+                x * mask, target * mask, reduction='none')
+            loss = loss.mul(out_weights.cuda()).sum()
+        else:
+            mask = mask.expand_as(x)
+            loss = functional.mse_loss(
+                x * mask, target * mask, reduction='sum')
+            loss = loss / mask.sum()
         return loss
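
Why this helps with length-imbalanced batches: with the default mask-sum normalization every valid frame carries the same weight, so a single long utterance dominates the batch loss. With seq_len_norm enabled, each sequence's frames are first weighted by 1 / (its own length) and then averaged over the batch and feature dimensions, so every utterance contributes equally. The standalone sketch below is not part of the commit (the toy tensors and the simplified sequence_mask helper are assumptions for illustration, and the .cuda() call is dropped so it runs on CPU); it only reproduces the two normalizations from the diff above on a two-sequence batch:

import torch
import torch.nn.functional as functional

def sequence_mask(sequence_length, max_len):
    # (batch, max_len) boolean mask of valid frames; stand-in for TTS.utils.generic_utils.sequence_mask
    seq_range = torch.arange(max_len).unsqueeze(0)
    return seq_range < sequence_length.unsqueeze(1)

x = torch.randn(2, 5, 3)        # predictions: (batch=2, max_len=5, feat=3)
target = torch.randn(2, 5, 3)
length = torch.tensor([5, 2])   # the second utterance is much shorter

mask = sequence_mask(length, 5).unsqueeze(2).float()   # (2, 5, 1)

# Default behaviour (seq_len_norm=False): one global average over all valid frames,
# so the 5-frame sequence supplies 5 of the 7 valid frames in the batch.
loss_global = functional.l1_loss(
    x * mask.expand_as(x), target * mask.expand_as(x), reduction='sum') / mask.sum()

# seq_len_norm=True: weight each frame by 1 / (its sequence length), then spread
# over the batch and feature dims, so both sequences contribute equally.
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_weights = norm_w.div(target.shape[0] * target.shape[2])
loss_norm = functional.l1_loss(
    x * mask.expand_as(x), target * mask.expand_as(x), reduction='none').mul(out_weights).sum()

print(loss_global.item(), loss_norm.item())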


@@ -561,8 +561,8 @@ def main(args):  # pylint: disable=redefined-outer-name
     optimizer_st = None
     if c.loss_masking:
-        criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
-                                                  ] else MSELossMasked()
+        criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron", "TacotronGST"
+                                                                ] else MSELossMasked(c.seq_len_norm)
     else:
         criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                                ] else nn.MSELoss()
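
Design note: the new flag is threaded in from the training config object (c.seq_len_norm in the hunk above), so enabling the behaviour should only require adding the corresponding seq_len_norm entry to the training config; only the criterion constructors change, and the rest of the training loop is untouched. The config file itself is not part of this diff, so its exact key location and default value are not shown here.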