mirror of https://github.com/coqui-ai/TTS.git
ssim loss for tacotron models
parent 9d0ae2bfb4
commit 9cef923d99
@@ -69,10 +69,14 @@
     // LOSS SETTINGS
     "loss_masking": true,       // enable / disable loss masking against the sequence padding.
-    "decoder_loss_alpha": 0.5,  // decoder loss weight. If > 0, it is enabled
-    "postnet_loss_alpha": 0.25, // postnet loss weight. If > 0, it is enabled
+    "decoder_loss_alpha": 0.5,  // original decoder loss weight. If > 0, it is enabled
+    "postnet_loss_alpha": 0.25, // original postnet loss weight. If > 0, it is enabled
+    "postnet_diff_spec_alpha": 0.25,    // differential spectral loss weight. If > 0, it is enabled
+    "decoder_diff_spec_alpha": 0.25,    // differential spectral loss weight. If > 0, it is enabled
+    "decoder_ssim_alpha": 0.5,  // decoder ssim loss weight. If > 0, it is enabled
+    "postnet_ssim_alpha": 0.25, // postnet ssim loss weight. If > 0, it is enabled
     "ga_alpha": 5.0,            // weight for guided attention loss. If > 0, guided attention is enabled.
-    "diff_spec_alpha": 0.25,    // differential spectral loss weight. If > 0, it is enabled

     // VALIDATION
     "run_eval": true,
@@ -5,6 +5,7 @@ from torch import nn
 from inspect import signature
 from torch.nn import functional
 from TTS.tts.utils.generic_utils import sequence_mask
+from TTS.tts.utils.ssim import ssim


 # pylint: disable=abstract-method Method
@@ -25,6 +26,10 @@ class L1LossMasked(nn.Module):
                 class for each corresponding step.
             length: A Variable containing a LongTensor of size (batch,)
                 which contains the length of each data in a batch.
+        Shapes:
+            x: B x T X D
+            target: B x T x D
+            length: B
         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
@@ -63,6 +68,10 @@ class MSELossMasked(nn.Module):
                 class for each corresponding step.
             length: A Variable containing a LongTensor of size (batch,)
                 which contains the length of each data in a batch.
+        Shapes:
+            x: B x T X D
+            target: B x T x D
+            length: B
         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
@@ -87,6 +96,33 @@ class MSELossMasked(nn.Module):
         return loss


+class SSIMLoss(torch.nn.Module):
+    """SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity"""
+    def __init__(self):
+        super().__init__()
+        self.loss_func = ssim
+
+    def forward(self, y_hat, y, length=None):
+        """
+        Args:
+            y_hat (tensor): model prediction values.
+            y (tensor): target values.
+            length (tensor): length of each sample in a batch.
+        Shapes:
+            y_hat: B x T X D
+            y: B x T x D
+            length: B
+        Returns:
+            loss: An average loss value in range [0, 1] masked by the length.
+        """
+        if length is not None:
+            m = sequence_mask(sequence_length=length,
+                              max_len=y.size(1)).unsqueeze(2).float().to(
+                                  y_hat.device)
+            y_hat, y = y_hat * m, y * m
+        return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1))
+
+
 class AttentionEntropyLoss(nn.Module):
     # pylint: disable=R0201
     def forward(self, align):
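For orientation, here is a minimal usage sketch of the new SSIMLoss on mel-spectrogram shaped tensors (batch x time x mel_dim). The tensor sizes and length values are illustrative only, not taken from the commit:

import torch
from TTS.tts.layers.losses import SSIMLoss

criterion = SSIMLoss()
y_hat = torch.rand(4, 80, 80)             # model prediction: (batch, time, mel_dim)
y = torch.rand(4, 80, 80)                 # ground-truth spectrogram
lengths = torch.tensor([80, 72, 64, 50])  # valid frames per sample; padded frames are masked

loss = criterion(y_hat, y, lengths)       # returns 1 - SSIM over the masked tensors
print(loss.item())                        # 0.0 when y_hat equals y, larger as they diverge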
@@ -118,6 +154,10 @@ class BCELossMasked(nn.Module):
                 class for each corresponding step.
             length: A Variable containing a LongTensor of size (batch,)
                 which contains the length of each data in a batch.
+        Shapes:
+            x: B x T
+            target: B x T
+            length: B
         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
@@ -142,13 +182,20 @@ class DifferentailSpectralLoss(nn.Module):
         super().__init__()
         self.loss_func = loss_func

-    def forward(self, x, target, length):
+    def forward(self, x, target, length=None):
+        """
+        Shapes:
+            x: B x T
+            target: B x T
+            length: B
+        Returns:
+            loss: An average loss value in range [0, 1] masked by the length.
+        """
         x_diff = x[:, 1:] - x[:, :-1]
         target_diff = target[:, 1:] - target[:, :-1]
-        if len(signature(self.loss_func).parameters) > 2:
-            return self.loss_func(x_diff, target_diff, length-1)
-        return self.loss_func(x_diff, target_diff)
+        if length is None:
+            # if loss masking is not enabled
+            return self.loss_func(x_diff, target_diff)
+        return self.loss_func(x_diff, target_diff, length-1)


 class GuidedAttentionLoss(torch.nn.Module):
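As a quick illustration of the reworked forward: the wrapper compares first-order differences along the time axis and now skips masking when no lengths are passed. A minimal sketch with illustrative tensors, assuming the package is installed:

import torch
from TTS.tts.layers.losses import DifferentailSpectralLoss, L1LossMasked

x = torch.rand(4, 80, 80)                 # model output: (batch, time, mel_dim)
target = torch.rand(4, 80, 80)            # ground-truth spectrogram
lengths = torch.tensor([80, 72, 64, 50])

# masked path: lengths shrink by one frame because the time-difference drops a frame
masked = DifferentailSpectralLoss(loss_func=L1LossMasked(seq_len_norm=False))
print(masked(x, target, lengths).item())

# unmasked path (length defaults to None): any plain two-argument loss works
unmasked = DifferentailSpectralLoss(loss_func=torch.nn.L1Loss())
print(unmasked(x, target).item())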
@@ -188,6 +235,7 @@ class GuidedAttentionLoss(torch.nn.Module):


 class TacotronLoss(torch.nn.Module):
+    """Collection of Tacotron set-up based on provided config."""
     def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4):
         super(TacotronLoss, self).__init__()
         self.stopnet_pos_weight = stopnet_pos_weight
@@ -195,6 +243,7 @@ class TacotronLoss(torch.nn.Module):
         self.diff_spec_alpha = c.diff_spec_alpha
         self.decoder_alpha = c.decoder_loss_alpha
         self.postnet_alpha = c.postnet_loss_alpha
+        self.ssim_alpha = c.ssim_alpha
         self.config = c

         # postnet and decoder loss
@@ -205,12 +254,15 @@ class TacotronLoss(torch.nn.Module):
         else:
             self.criterion = nn.L1Loss() if c.model in ["Tacotron"
                                                         ] else nn.MSELoss()
-        # differential spectral loss
-        if c.diff_spec_alpha > 0:
-            self.criterion_diff_spec = DifferentailSpectralLoss(loss_func=self.criterion)
         # guided attention loss
         if c.ga_alpha > 0:
             self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma)
+        # differential spectral loss
+        if c.postnet_diff_spec_alpha > 0 or c.decoder_diff_spec_alpha > 0:
+            self.criterion_diff_spec = DifferentailSpectralLoss(loss_func=self.criterion)
+        # ssim loss
+        if c.postnet_ssim_alpha > 0 or c.decoder_ssim_alpha > 0:
+            self.criterion_ssim = SSIMLoss()
         # stopnet loss
         # pylint: disable=not-callable
         self.criterion_st = BCELossMasked(
@@ -221,6 +273,9 @@ class TacotronLoss(torch.nn.Module):
                 alignments, alignment_lens, alignments_backwards, input_lens):

         return_dict = {}
+        # remove lengths if no masking is applied
+        if not self.config.loss_masking:
+            output_lens = None
         # decoder and postnet losses
         if self.config.loss_masking:
             if self.decoder_alpha > 0:
@@ -285,11 +340,30 @@ class TacotronLoss(torch.nn.Module):
             loss += ga_loss * self.ga_alpha
             return_dict['ga_loss'] = ga_loss * self.ga_alpha

-        # differential spectral loss
-        if self.config.diff_spec_alpha > 0:
-            diff_spec_loss = self.criterion_diff_spec(postnet_output, mel_input, output_lens)
-            loss += diff_spec_loss * self.diff_spec_alpha
-            return_dict['diff_spec_loss'] = diff_spec_loss
+        # decoder differential spectral loss
+        if self.config.decoder_diff_spec_alpha > 0:
+            decoder_diff_spec_loss = self.criterion_diff_spec(decoder_output, mel_input, output_lens)
+            loss += decoder_diff_spec_loss * self.decoder_diff_spec_alpha
+            return_dict['decoder_diff_spec_loss'] = decoder_diff_spec_loss
+
+        # postnet differential spectral loss
+        if self.config.postnet_diff_spec_alpha > 0:
+            postnet_diff_spec_loss = self.criterion_diff_spec(postnet_output, mel_input, output_lens)
+            loss += postnet_diff_spec_loss * self.postnet_diff_spec_alpha
+            return_dict['postnet_diff_spec_loss'] = postnet_diff_spec_loss
+
+        # decoder ssim loss
+        if self.config.decoder_ssim_alpha > 0:
+            decoder_ssim_loss = self.criterion_ssim(decoder_output, mel_input, output_lens)
+            loss += decoder_ssim_loss * self.postnet_ssim_alpha
+            return_dict['decoder_ssim_loss'] = decoder_ssim_loss
+
+        # postnet ssim loss
+        if self.config.postnet_ssim_alpha > 0:
+            postnet_ssim_loss = self.criterion_ssim(postnet_output, mel_input, output_lens)
+            loss += postnet_ssim_loss * self.postnet_ssim_alpha
+            return_dict['postnet_ssim_loss'] = postnet_ssim_loss

         return_dict['loss'] = loss
         return return_dict
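Summing up, the objective assembled by TacotronLoss schematically becomes (a summary for orientation, not an expression copied from the code):

\[
\mathcal{L} = \alpha_{dec}\,\mathcal{L}_{dec} + \alpha_{post}\,\mathcal{L}_{post} + \mathcal{L}_{stop}
+ \alpha_{ga}\,\mathcal{L}_{ga}
+ \alpha^{diff}_{dec}\,\mathcal{L}^{diff}_{dec} + \alpha^{diff}_{post}\,\mathcal{L}^{diff}_{post}
+ \alpha^{ssim}_{dec}\,\mathcal{L}^{ssim}_{dec} + \alpha^{ssim}_{post}\,\mathcal{L}^{ssim}_{post}
\]

where each added term drops out entirely when its weight is set to 0 in the config.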
@@ -178,10 +178,19 @@ def check_config_tts(c):
     check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
     check_argument('r', c, restricted=True, val_type=int, min_val=1)
     check_argument('gradual_training', c, restricted=False, val_type=list)
-    check_argument('loss_masking', c, restricted=True, val_type=bool)
     check_argument('apex_amp_level', c, restricted=False, val_type=str)
     # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)

+    # loss parameters
+    check_argument('loss_masking', c, restricted=True, val_type=bool)
+    check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
+    check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
+
     # validation parameters
     check_argument('run_eval', c, restricted=True, val_type=bool)
     check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
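A small sanity-check sketch of what the stricter validation implies: configs must now carry every new alpha as a non-negative float. The dict below is a hypothetical fragment, and it assumes check_argument (imported here from the module that defines check_config_tts; the exact location is an assumption) asserts on a restricted key that is missing or has the wrong type:

from TTS.tts.utils.generic_utils import check_argument  # assumed location of the helper

c = {"decoder_ssim_alpha": 0.5, "postnet_ssim_alpha": 0.25}  # hypothetical config fragment

# passes: key present, float, >= 0
check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)

# expected to fail: 'ga_alpha' is restricted but missing from the fragment
check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)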
@@ -2,7 +2,7 @@ import unittest
 import torch as T

 from TTS.tts.layers.tacotron import Prenet, CBHG, Decoder, Encoder
-from TTS.tts.layers.losses import L1LossMasked
+from TTS.tts.layers.losses import L1LossMasked, SSIMLoss
 from TTS.tts.utils.generic_utils import sequence_mask

 # pylint: disable=unused-variable
@@ -149,3 +149,72 @@ class L1LossMaskedTests(unittest.TestCase):
             (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 0, "0 vs {}".format(output.item())
+
+
+class SSIMLossTests(unittest.TestCase):
+    def test_in_out(self):  #pylint: disable=no-self-use
+        # test input == target
+        layer = SSIMLoss()
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.ones(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 0.0
+
+        # test input != target
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())
+
+        # test if padded values of input makes any difference
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = (
+            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())
+
+        dummy_input = T.rand(4, 8, 128).float()
+        dummy_target = dummy_input.detach()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = (
+            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 0, "0 vs {}".format(output.item())
+
+        # seq_len_norm = True
+        # test input == target
+        layer = L1LossMasked(seq_len_norm=True)
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.ones(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 0.0
+
+        # test input != target
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+
+        # test if padded values of input makes any difference
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = (
+            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
+
+        dummy_input = T.rand(4, 8, 128).float()
+        dummy_target = dummy_input.detach()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = (
+            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 0, "0 vs {}".format(output.item())