seq_len_norm for imbalanced datasets

This commit is contained in:
root 2020-01-15 12:30:07 +01:00 committed by erogol
parent 72817438db
commit 9921d682c3
2 changed files with 34 additions and 10 deletions

View File

@ -6,6 +6,11 @@ from TTS.utils.generic_utils import sequence_mask
class L1LossMasked(nn.Module):
def __init__(self, seq_len_norm):
super(L1LossMasked, self).__init__()
self.seq_len_norm = seq_len_norm
def forward(self, x, target, length):
"""
Args:
@ -24,14 +29,26 @@ class L1LossMasked(nn.Module):
target.requires_grad = False
mask = sequence_mask(
sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
mask = mask.expand_as(x)
loss = functional.l1_loss(
x * mask, target * mask, reduction="sum")
loss = loss / mask.sum()
if self.seq_len_norm:
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_weights = norm_w.div(target.shape[0] * target.shape[2])
mask = mask.expand_as(x)
loss = functional.l1_loss(
x * mask, target * mask, reduction='none')
loss = loss.mul(out_weights.cuda()).sum()
else:
loss = functional.l1_loss(
x * mask, target * mask, reduction='sum')
loss = loss / mask.sum()
return loss
class MSELossMasked(nn.Module):
def __init__(self, seq_len_norm):
super(MSELossMasked, self).__init__()
self.seq_len_norm = seq_len_norm
def forward(self, x, target, length):
"""
Args:
@ -50,10 +67,17 @@ class MSELossMasked(nn.Module):
target.requires_grad = False
mask = sequence_mask(
sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
mask = mask.expand_as(x)
loss = functional.mse_loss(
x * mask, target * mask, reduction="sum")
loss = loss / mask.sum()
if self.seq_len_norm:
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_weights = norm_w.div(target.shape[0] * target.shape[2])
mask = mask.expand_as(x)
loss = functional.mse_loss(
x * mask, target * mask, reduction='none')
loss = loss.mul(out_weights.cuda()).sum()
else:
loss = functional.mse_loss(
x * mask, target * mask, reduction='sum')
loss = loss / mask.sum()
return loss

View File

@ -561,8 +561,8 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer_st = None
if c.loss_masking:
criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
] else MSELossMasked()
criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron", "TacotronGST"
] else MSELossMasked(c.seq_len_norm)
else:
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
] else nn.MSELoss()