mirror of https://github.com/coqui-ai/TTS.git
seq_len_norm for imbalanced datasets
This commit is contained in:
parent 72817438db
commit 9921d682c3
@@ -6,6 +6,11 @@ from TTS.utils.generic_utils import sequence_mask


 class L1LossMasked(nn.Module):

+    def __init__(self, seq_len_norm):
+        super(L1LossMasked, self).__init__()
+        self.seq_len_norm = seq_len_norm
+
     def forward(self, x, target, length):
         """
         Args:
@@ -24,14 +29,26 @@ class L1LossMasked(nn.Module):
         target.requires_grad = False
         mask = sequence_mask(
             sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
-        mask = mask.expand_as(x)
-        loss = functional.l1_loss(
-            x * mask, target * mask, reduction="sum")
-        loss = loss / mask.sum()
+        if self.seq_len_norm:
+            norm_w = mask / mask.sum(dim=1, keepdim=True)
+            out_weights = norm_w.div(target.shape[0] * target.shape[2])
+            mask = mask.expand_as(x)
+            loss = functional.l1_loss(
+                x * mask, target * mask, reduction='none')
+            loss = loss.mul(out_weights.cuda()).sum()
+        else:
+            mask = mask.expand_as(x)
+            loss = functional.l1_loss(
+                x * mask, target * mask, reduction='sum')
+            loss = loss / mask.sum()
         return loss


 class MSELossMasked(nn.Module):

+    def __init__(self, seq_len_norm):
+        super(MSELossMasked, self).__init__()
+        self.seq_len_norm = seq_len_norm
+
     def forward(self, x, target, length):
         """
         Args:
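The arithmetic in the new seq_len_norm branch is easier to see in isolation. The sketch below is a minimal, self-contained demo, not part of the commit: the sequence_mask helper is a simplified stand-in for TTS.utils.generic_utils.sequence_mask, and the toy shapes are invented. The .cuda() call from the diff is dropped so the sketch also runs on CPU; out_weights.to(loss.device) would be the device-agnostic spelling.

import torch
from torch.nn import functional

def sequence_mask(sequence_length, max_len):
    # Simplified stand-in for TTS.utils.generic_utils.sequence_mask:
    # (batch, max_len) float mask, 1.0 on valid steps, 0.0 on padding.
    steps = torch.arange(max_len).unsqueeze(0)          # (1, max_len)
    return (steps < sequence_length.unsqueeze(1)).float()

# Toy batch with very imbalanced lengths (shapes invented for the demo).
B, T, D = 2, 10, 3
x, target = torch.randn(B, T, D), torch.randn(B, T, D)
length = torch.tensor([10, 2])
mask = sequence_mask(length, T).unsqueeze(2)            # (B, T, 1)

# Old behaviour: masked mean. Every valid frame weighs the same,
# so the length-10 utterance dominates the batch loss.
m = mask.expand_as(x)
loss_mean = functional.l1_loss(x * m, target * m, reduction='sum') / m.sum()

# seq_len_norm branch: weights first sum to 1 within each utterance,
# then are divided by batch * n_features, so utterances count equally.
norm_w = mask / mask.sum(dim=1, keepdim=True)           # (B, T, 1)
out_weights = norm_w.div(target.shape[0] * target.shape[2])
loss_elem = functional.l1_loss(x * m, target * m, reduction='none')
loss_norm = loss_elem.mul(out_weights).sum()            # broadcasts over D

print(out_weights.expand_as(x).sum())                   # tensor(1.), whatever the lengths
print(loss_mean.item(), loss_norm.item())

With lengths 10 and 2, the masked mean lets the long utterance supply 10 of the 12 valid frames, hence most of the gradient; under seq_len_norm each utterance contributes exactly 1/batch_size of the total weight, which is the point of the commit for length-imbalanced datasets.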
@@ -50,10 +67,17 @@ class MSELossMasked(nn.Module):
         target.requires_grad = False
         mask = sequence_mask(
             sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
-        mask = mask.expand_as(x)
-        loss = functional.mse_loss(
-            x * mask, target * mask, reduction="sum")
-        loss = loss / mask.sum()
+        if self.seq_len_norm:
+            norm_w = mask / mask.sum(dim=1, keepdim=True)
+            out_weights = norm_w.div(target.shape[0] * target.shape[2])
+            mask = mask.expand_as(x)
+            loss = functional.mse_loss(
+                x * mask, target * mask, reduction='none')
+            loss = loss.mul(out_weights.cuda()).sum()
+        else:
+            mask = mask.expand_as(x)
+            loss = functional.mse_loss(
+                x * mask, target * mask, reduction='sum')
+            loss = loss / mask.sum()
         return loss
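MSELossMasked receives the identical change, with functional.mse_loss in place of functional.l1_loss. One property worth noting: when every sequence in a batch spans the full max_len, the mask is all ones, norm_w is uniformly 1/max_len, and the seq_len_norm weighting collapses to the plain masked mean, so the flag only alters batches with mixed lengths. A quick check, reusing the toy tensors from the sketch above:

# Same x, target, B, T, D as before, but with no padding anywhere.
length_full = torch.full((B,), T)
mask = sequence_mask(length_full, T).unsqueeze(2)
m = mask.expand_as(x)
loss_mean = functional.mse_loss(x * m, target * m, reduction='sum') / m.sum()
norm_w = mask / mask.sum(dim=1, keepdim=True)           # every weight is 1/T
out_weights = norm_w.div(target.shape[0] * target.shape[2])
loss_norm = functional.mse_loss(
    x * m, target * m, reduction='none').mul(out_weights).sum()
print(torch.allclose(loss_mean, loss_norm))             # True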
train.py (4 changed lines: 2 additions, 2 deletions)
@@ -561,8 +561,8 @@ def main(args):  # pylint: disable=redefined-outer-name
     optimizer_st = None

     if c.loss_masking:
-        criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
-                                                  ] else MSELossMasked()
+        criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron", "TacotronGST"
+                                                                ] else MSELossMasked(c.seq_len_norm)
     else:
         criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                                ] else nn.MSELoss()
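train.py threads the new flag in from the config object, so enabling it should be a one-line config change; the matching config.json edit is not among the lines shown here, and the field name below is inferred from c.seq_len_norm. A hedged usage sketch, with the import path assumed rather than taken from this diff:

# Assumed config.json addition (field name inferred, not shown in this commit):
#     "seq_len_norm": true

from TTS.layers.losses import L1LossMasked   # import path assumed

criterion = L1LossMasked(seq_len_norm=True)
# x, target: (batch, max_len, n_feats); length: (batch,) valid-step counts
loss = criterion(x, target, length)
# Caveat: as committed, the seq_len_norm branch calls out_weights.cuda(),
# so this path expects training to run on a GPU.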