From 9cef923d991f57b697f749295dc242e5c745ed1e Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 28 Oct 2020 15:24:18 +0100 Subject: [PATCH 1/4] ssim loss for tacotron models --- TTS/tts/configs/config.json | 10 +++- TTS/tts/layers/losses.py | 100 ++++++++++++++++++++++++++++----- TTS/tts/utils/generic_utils.py | 11 +++- tests/test_layers.py | 71 ++++++++++++++++++++++- 4 files changed, 174 insertions(+), 18 deletions(-) diff --git a/TTS/tts/configs/config.json b/TTS/tts/configs/config.json index 1b63b037..4d3e2674 100644 --- a/TTS/tts/configs/config.json +++ b/TTS/tts/configs/config.json @@ -69,10 +69,14 @@ // LOSS SETTINGS "loss_masking": true, // enable / disable loss masking against the sequence padding. - "decoder_loss_alpha": 0.5, // decoder loss weight. If > 0, it is enabled - "postnet_loss_alpha": 0.25, // postnet loss weight. If > 0, it is enabled + "decoder_loss_alpha": 0.5, // original decoder loss weight. If > 0, it is enabled + "postnet_loss_alpha": 0.25, // original postnet loss weight. If > 0, it is enabled + "postnet_diff_spec_alpha": 0.25, // differential spectral loss weight. If > 0, it is enabled + "decoder_diff_spec_alpha": 0.25, // differential spectral loss weight. If > 0, it is enabled + "decoder_ssim_alpha": 0.5, // decoder ssim loss weight. If > 0, it is enabled + "postnet_ssim_alpha": 0.25, // postnet ssim loss weight. If > 0, it is enabled "ga_alpha": 5.0, // weight for guided attention loss. If > 0, guided attention is enabled. - "diff_spec_alpha": 0.25, // differential spectral loss weight. If > 0, it is enabled + // VALIDATION "run_eval": true, diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 67503a76..10ee3905 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -5,6 +5,7 @@ from torch import nn from inspect import signature from torch.nn import functional from TTS.tts.utils.generic_utils import sequence_mask +from TTS.tts.utils.ssim import ssim # pylint: disable=abstract-method Method @@ -25,6 +26,10 @@ class L1LossMasked(nn.Module): class for each corresponding step. length: A Variable containing a LongTensor of size (batch,) which contains the length of each data in a batch. + Shapes: + x: B x T X D + target: B x T x D + length: B Returns: loss: An average loss value in range [0, 1] masked by the length. """ @@ -63,6 +68,10 @@ class MSELossMasked(nn.Module): class for each corresponding step. length: A Variable containing a LongTensor of size (batch,) which contains the length of each data in a batch. + Shapes: + x: B x T X D + target: B x T x D + length: B Returns: loss: An average loss value in range [0, 1] masked by the length. """ @@ -87,6 +96,33 @@ class MSELossMasked(nn.Module): return loss +class SSIMLoss(torch.nn.Module): + """SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity""" + def __init__(self): + super().__init__() + self.loss_func = ssim + + def forward(self, y_hat, y, length=None): + """ + Args: + y_hat (tensor): model prediction values. + y (tensor): target values. + length (tensor): length of each sample in a batch. + Shapes: + y_hat: B x T X D + y: B x T x D + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. 
+ """ + if length is not None: + m = sequence_mask(sequence_length=length, + max_len=y.size(1)).unsqueeze(2).float().to( + y_hat.device) + y_hat, y = y_hat * m, y * m + return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1)) + + class AttentionEntropyLoss(nn.Module): # pylint: disable=R0201 def forward(self, align): @@ -118,6 +154,10 @@ class BCELossMasked(nn.Module): class for each corresponding step. length: A Variable containing a LongTensor of size (batch,) which contains the length of each data in a batch. + Shapes: + x: B x T + target: B x T + length: B Returns: loss: An average loss value in range [0, 1] masked by the length. """ @@ -142,13 +182,20 @@ class DifferentailSpectralLoss(nn.Module): super().__init__() self.loss_func = loss_func - def forward(self, x, target, length): + def forward(self, x, target, length=None): + """ + Shapes: + x: B x T + target: B x T + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ x_diff = x[:, 1:] - x[:, :-1] target_diff = target[:, 1:] - target[:, :-1] - if len(signature(self.loss_func).parameters) > 2: - return self.loss_func(x_diff, target_diff, length-1) - # if loss masking is not enabled - return self.loss_func(x_diff, target_diff) + if length is None: + return self.loss_func(x_diff, target_diff) + return self.loss_func(x_diff, target_diff, length-1) class GuidedAttentionLoss(torch.nn.Module): @@ -188,6 +235,7 @@ class GuidedAttentionLoss(torch.nn.Module): class TacotronLoss(torch.nn.Module): + """Collection of Tacotron set-up based on provided config.""" def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4): super(TacotronLoss, self).__init__() self.stopnet_pos_weight = stopnet_pos_weight @@ -195,6 +243,7 @@ class TacotronLoss(torch.nn.Module): self.diff_spec_alpha = c.diff_spec_alpha self.decoder_alpha = c.decoder_loss_alpha self.postnet_alpha = c.postnet_loss_alpha + self.ssim_alpha = c.ssim_alpha self.config = c # postnet and decoder loss @@ -205,12 +254,15 @@ class TacotronLoss(torch.nn.Module): else: self.criterion = nn.L1Loss() if c.model in ["Tacotron" ] else nn.MSELoss() - # differential spectral loss - if c.diff_spec_alpha > 0: - self.criterion_diff_spec = DifferentailSpectralLoss(loss_func=self.criterion) # guided attention loss if c.ga_alpha > 0: self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma) + # differential spectral loss + if c.postnet_diff_spec_alpha > 0 or c.decoder_diff_spec_alpha > 0: + self.criterion_diff_spec = DifferentailSpectralLoss(loss_func=self.criterion) + # ssim loss + if c.postnet_ssim_alpha > 0 or c.decoder_ssim_alpha > 0: + self.criterion_ssim = SSIMLoss() # stopnet loss # pylint: disable=not-callable self.criterion_st = BCELossMasked( @@ -221,6 +273,9 @@ class TacotronLoss(torch.nn.Module): alignments, alignment_lens, alignments_backwards, input_lens): return_dict = {} + # remove lengths if no masking is applied + if not self.config.loss_masking: + output_lens = None # decoder and postnet losses if self.config.loss_masking: if self.decoder_alpha > 0: @@ -285,11 +340,30 @@ class TacotronLoss(torch.nn.Module): loss += ga_loss * self.ga_alpha return_dict['ga_loss'] = ga_loss * self.ga_alpha - # differential spectral loss - if self.config.diff_spec_alpha > 0: - diff_spec_loss = self.criterion_diff_spec(postnet_output, mel_input, output_lens) - loss += diff_spec_loss * self.diff_spec_alpha - return_dict['diff_spec_loss'] = diff_spec_loss + # decoder differential spectral loss + if self.config.decoder_diff_spec_alpha > 0: + decoder_diff_spec_loss = 
self.criterion_diff_spec(decoder_output, mel_input, output_lens)
+            loss += decoder_diff_spec_loss * self.decoder_diff_spec_alpha
+            return_dict['decoder_diff_spec_loss'] = decoder_diff_spec_loss
+
+        # postnet differential spectral loss
+        if self.config.postnet_diff_spec_alpha > 0:
+            postnet_diff_spec_loss = self.criterion_diff_spec(postnet_output, mel_input, output_lens)
+            loss += postnet_diff_spec_loss * self.postnet_diff_spec_alpha
+            return_dict['postnet_diff_spec_loss'] = postnet_diff_spec_loss
+
+        # decoder ssim loss
+        if self.config.decoder_ssim_alpha > 0:
+            decoder_ssim_loss = self.criterion_ssim(decoder_output, mel_input, output_lens)
+            loss += decoder_ssim_loss * self.decoder_ssim_alpha
+            return_dict['decoder_ssim_loss'] = decoder_ssim_loss
+
+        # postnet ssim loss
+        if self.config.postnet_ssim_alpha > 0:
+            postnet_ssim_loss = self.criterion_ssim(postnet_output, mel_input, output_lens)
+            loss += postnet_ssim_loss * self.postnet_ssim_alpha
+            return_dict['postnet_ssim_loss'] = postnet_ssim_loss
+
         return_dict['loss'] = loss
         return return_dict
diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py
index 2361fa85..2c82611f 100644
--- a/TTS/tts/utils/generic_utils.py
+++ b/TTS/tts/utils/generic_utils.py
@@ -178,10 +178,19 @@ def check_config_tts(c):
     check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
     check_argument('r', c, restricted=True, val_type=int, min_val=1)
     check_argument('gradual_training', c, restricted=False, val_type=list)
-    check_argument('loss_masking', c, restricted=True, val_type=bool)
     check_argument('apex_amp_level', c, restricted=False, val_type=str)
     # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
 
+    # loss parameters
+    check_argument('loss_masking', c, restricted=True, val_type=bool)
+    check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
+
     # validation parameters
     check_argument('run_eval', c, restricted=True, val_type=bool)
     check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 57be51e5..5426e195 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -2,7 +2,7 @@ import unittest
 import torch as T
 
 from TTS.tts.layers.tacotron import Prenet, CBHG, Decoder, Encoder
-from TTS.tts.layers.losses import L1LossMasked
+from TTS.tts.layers.losses import L1LossMasked, SSIMLoss
 from TTS.tts.utils.generic_utils import sequence_mask
 
 # pylint: disable=unused-variable
@@ -149,3 +149,72 @@ class L1LossMaskedTests(unittest.TestCase):
             (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 0, "0 vs {}".format(output.item())
+
+
+class SSIMLossTests(unittest.TestCase):
+    def test_in_out(self):  #pylint: disable=no-self-use
+        # test input == target
+        layer = SSIMLoss()
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.ones(4, 8, 128).float()
+        dummy_length = 
(T.ones(4) * 8).long() + output = layer(dummy_input, dummy_target, dummy_length) + assert output.item() == 0.0 + + # test input != target + dummy_input = T.ones(4, 8, 128).float() + dummy_target = T.zeros(4, 8, 128).float() + dummy_length = (T.ones(4) * 8).long() + output = layer(dummy_input, dummy_target, dummy_length) + assert abs(output.item() - 1.0) < 1e-4 , "1.0 vs {}".format(output.item()) + + # test if padded values of input makes any difference + dummy_input = T.ones(4, 8, 128).float() + dummy_target = T.zeros(4, 8, 128).float() + dummy_length = (T.arange(5, 9)).long() + mask = ( + (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + output = layer(dummy_input + mask, dummy_target, dummy_length) + assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item()) + + dummy_input = T.rand(4, 8, 128).float() + dummy_target = dummy_input.detach() + dummy_length = (T.arange(5, 9)).long() + mask = ( + (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + output = layer(dummy_input + mask, dummy_target, dummy_length) + assert output.item() == 0, "0 vs {}".format(output.item()) + + # seq_len_norm = True + # test input == target + layer = L1LossMasked(seq_len_norm=True) + dummy_input = T.ones(4, 8, 128).float() + dummy_target = T.ones(4, 8, 128).float() + dummy_length = (T.ones(4) * 8).long() + output = layer(dummy_input, dummy_target, dummy_length) + assert output.item() == 0.0 + + # test input != target + dummy_input = T.ones(4, 8, 128).float() + dummy_target = T.zeros(4, 8, 128).float() + dummy_length = (T.ones(4) * 8).long() + output = layer(dummy_input, dummy_target, dummy_length) + assert output.item() == 1.0, "1.0 vs {}".format(output.item()) + + # test if padded values of input makes any difference + dummy_input = T.ones(4, 8, 128).float() + dummy_target = T.zeros(4, 8, 128).float() + dummy_length = (T.arange(5, 9)).long() + mask = ( + (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + output = layer(dummy_input + mask, dummy_target, dummy_length) + assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item()) + + dummy_input = T.rand(4, 8, 128).float() + dummy_target = dummy_input.detach() + dummy_length = (T.arange(5, 9)).long() + mask = ( + (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + output = layer(dummy_input + mask, dummy_target, dummy_length) + assert output.item() == 0, "0 vs {}".format(output.item()) + From 59e1cf99d0dc5b00e276c72348737f2613f163e5 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 28 Oct 2020 18:30:00 +0100 Subject: [PATCH 2/4] config update and ssim implementation --- TTS/tts/configs/config.json | 1 + TTS/tts/utils/ssim.py | 75 +++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 TTS/tts/utils/ssim.py diff --git a/TTS/tts/configs/config.json b/TTS/tts/configs/config.json index 4d3e2674..2cad69c3 100644 --- a/TTS/tts/configs/config.json +++ b/TTS/tts/configs/config.json @@ -76,6 +76,7 @@ "decoder_ssim_alpha": 0.5, // decoder ssim loss weight. If > 0, it is enabled "postnet_ssim_alpha": 0.25, // postnet ssim loss weight. If > 0, it is enabled "ga_alpha": 5.0, // weight for guided attention loss. If > 0, guided attention is enabled. + "stopnet_pos_weight": 15.0, // pos class weight for stopnet loss since there are way more negative samples than positive samples. 
// VALIDATION diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py new file mode 100644 index 00000000..c370f5e5 --- /dev/null +++ b/TTS/tts/utils/ssim.py @@ -0,0 +1,75 @@ +# taken from https://github.com/Po-Hsun-Su/pytorch-ssim + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class SSIM(torch.nn.Module): + def __init__(self, window_size = 11, size_average = True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + +def ssim(img1, img2, window_size = 11, size_average = True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) \ No newline at end of file From e49cc3bbcdd9d82e3556b1b2b59cc98a1742984e Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 28 Oct 2020 18:34:34 +0100 Subject: [PATCH 3/4] bug fix --- TTS/tts/layers/losses.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 10ee3905..efd0c2cb 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -240,10 +240,12 @@ class TacotronLoss(torch.nn.Module): super(TacotronLoss, self).__init__() self.stopnet_pos_weight = stopnet_pos_weight self.ga_alpha = c.ga_alpha - self.diff_spec_alpha = c.diff_spec_alpha + self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha + self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha self.decoder_alpha = c.decoder_loss_alpha self.postnet_alpha = c.postnet_loss_alpha - self.ssim_alpha = c.ssim_alpha + 
self.decoder_ssim_alpha = c.decoder_ssim_alpha + self.postnet_ssim_alpha = c.postnet_ssim_alpha self.config = c # postnet and decoder loss From fdaed45f58712427067c433d981ea270bd5f1d63 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 28 Oct 2020 18:40:54 +0100 Subject: [PATCH 4/4] optional loss masking for stoptoken predictor --- TTS/tts/layers/losses.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index efd0c2cb..f26cb884 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -163,14 +163,20 @@ class BCELossMasked(nn.Module): """ # mask: (batch, max_len, 1) target.requires_grad = False - mask = sequence_mask(sequence_length=length, - max_len=target.size(1)).float() + if length is not None: + mask = sequence_mask(sequence_length=length, + max_len=target.size(1)).float() + x = x * mask + target = target * mask + num_items = mask.sum() + else: + num_items = torch.numel(x) loss = functional.binary_cross_entropy_with_logits( - x * mask, - target * mask, + x, + target, pos_weight=self.pos_weight, reduction='sum') - loss = loss / mask.sum() + loss = loss / num_items return loss
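
A quick usage sketch of the new SSIMLoss (not part of the patches above), mirroring what SSIMLossTests exercises. The imports, the call signature and the masking behaviour come from the code added in this series; the 80-bin mel dimension and the example lengths are illustrative assumptions.

    import torch

    from TTS.tts.layers.losses import SSIMLoss
    from TTS.tts.utils.generic_utils import sequence_mask

    loss_fn = SSIMLoss()

    # a batch of 4 spectrograms with 8 frames and 80 mel bins (B x T x D)
    y = torch.rand(4, 8, 80)
    y_hat = y.clone()
    lengths = torch.arange(5, 9).long()  # number of valid frames per sample

    # identical prediction and target -> SSIM == 1 -> loss == 1 - SSIM == 0
    print(loss_fn(y_hat, y, lengths).item())

    # corrupt only the padded frames; SSIMLoss masks them out before comparing,
    # so the loss stays at 0 (the same property the new tests assert)
    pad_noise = ((sequence_mask(lengths).float() - 1.0) * 100.0).unsqueeze(2)
    print(loss_fn(y_hat + pad_noise, y, lengths).item())

Unlike the point-wise L1/MSE terms, the SSIM term compares local statistics over an 11x11 Gaussian window (see the vendored TTS/tts/utils/ssim.py), and TacotronLoss blends it into the total loss through the new decoder_ssim_alpha and postnet_ssim_alpha weights.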