From 9921d682c325d6f7159c71969bbfdb228c685329 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 15 Jan 2020 12:30:07 +0100
Subject: [PATCH] seq_len_norm for imbalanced datasets

---
 layers/losses.py | 40 ++++++++++++++++++++++++++++++++--------
 train.py         |  4 ++--
 2 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/layers/losses.py b/layers/losses.py
index e7ecff5f..b8b17c17 100644
--- a/layers/losses.py
+++ b/layers/losses.py
@@ -6,6 +6,11 @@ from TTS.utils.generic_utils import sequence_mask
 
 
 class L1LossMasked(nn.Module):
+
+    def __init__(self, seq_len_norm):
+        super(L1LossMasked, self).__init__()
+        self.seq_len_norm = seq_len_norm
+
     def forward(self, x, target, length):
         """
         Args:
@@ -24,14 +29,26 @@ class L1LossMasked(nn.Module):
         target.requires_grad = False
         mask = sequence_mask(
             sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
-        mask = mask.expand_as(x)
-        loss = functional.l1_loss(
-            x * mask, target * mask, reduction="sum")
-        loss = loss / mask.sum()
+        if self.seq_len_norm:
+            norm_w = mask / mask.sum(dim=1, keepdim=True)
+            out_weights = norm_w.div(target.shape[0] * target.shape[2])
+            mask = mask.expand_as(x)
+            loss = functional.l1_loss(
+                x * mask, target * mask, reduction='none')
+            loss = loss.mul(out_weights.cuda()).sum()
+        else:
+            loss = functional.l1_loss(
+                x * mask, target * mask, reduction='sum')
+            loss = loss / mask.sum()
         return loss
 
 
 class MSELossMasked(nn.Module):
+
+    def __init__(self, seq_len_norm):
+        super(MSELossMasked, self).__init__()
+        self.seq_len_norm = seq_len_norm
+
     def forward(self, x, target, length):
         """
         Args:
@@ -50,10 +67,17 @@ class MSELossMasked(nn.Module):
         target.requires_grad = False
         mask = sequence_mask(
             sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
-        mask = mask.expand_as(x)
-        loss = functional.mse_loss(
-            x * mask, target * mask, reduction="sum")
-        loss = loss / mask.sum()
+        if self.seq_len_norm:
+            norm_w = mask / mask.sum(dim=1, keepdim=True)
+            out_weights = norm_w.div(target.shape[0] * target.shape[2])
+            mask = mask.expand_as(x)
+            loss = functional.mse_loss(
+                x * mask, target * mask, reduction='none')
+            loss = loss.mul(out_weights.cuda()).sum()
+        else:
+            loss = functional.mse_loss(
+                x * mask, target * mask, reduction='sum')
+            loss = loss / mask.sum()
         return loss
 

diff --git a/train.py b/train.py
index 81bc2c72..f52d24c1 100644
--- a/train.py
+++ b/train.py
@@ -561,8 +561,8 @@ def main(args):  # pylint: disable=redefined-outer-name
         optimizer_st = None
 
     if c.loss_masking:
-        criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
-                                                  ] else MSELossMasked()
+        criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron", "TacotronGST"
+                                                                ] else MSELossMasked(c.seq_len_norm)
     else:
         criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                                ] else nn.MSELoss()
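
Note on what the new branch computes (commentary, not part of the patch): with
seq_len_norm enabled, the valid frames of each sequence are first normalized so
they sum to one, which makes every utterance contribute equally to the batch
loss regardless of its length; the default path keeps the old behavior of
averaging over all valid positions, where long utterances dominate the
gradient. The sketch below replays the weighting math on dummy tensors; the
shapes are illustrative assumptions of mine, and a hand-built mask stands in
for TTS.utils.generic_utils.sequence_mask so it runs standalone on CPU.

import torch

# Illustrative shapes (assumptions, not from the patch): batch of 2,
# max length 4, 3 output channels; the second sequence is half as long.
B, T, D = 2, 4, 3
lengths = torch.tensor([4, 2])
x = torch.rand(B, T, D)
target = torch.rand(B, T, D)

# Binary mask of valid frames, shape (B, T, 1), standing in for sequence_mask().
mask = (torch.arange(T)[None, :] < lengths[:, None]).unsqueeze(2).float()

# seq_len_norm weighting as in the patch: each sequence's frames sum to 1,
# then divide by batch size and channel count so the weighted sum is a mean.
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_weights = norm_w.div(target.shape[0] * target.shape[2])

mask = mask.expand_as(x)
loss = torch.nn.functional.l1_loss(x * mask, target * mask, reduction='none')
loss = loss.mul(out_weights).sum()   # the patch calls .cuda(); omitted here

# Each sequence carries total weight 1/B: the short sequence's frames weigh
# 1/2 each and the long one's 1/4 each, before the division by B * D.
print(out_weights.squeeze(2))
print(loss)

With these weights every sequence contributes exactly 1/batch_size to the
total loss, which is the point of the option for datasets whose utterance
lengths are strongly imbalanced.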