From 48a4f3647fe29abdc2b770afee43341f650e1007 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Tue, 12 Jul 2022 14:58:26 +0200 Subject: [PATCH] Make lint --- TTS/tts/layers/losses.py | 4 ++-- TTS/tts/utils/ssim.py | 9 ++------- tests/tts_tests/test_losses.py | 4 +--- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index e43dd6b1..4430d9ff 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -133,11 +133,11 @@ class SSIMLoss(torch.nn.Module): if ssim_loss.item() > 1.0: print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0") - ssim_loss == 1.0 + ssim_loss = torch.tensor([1.0]) if ssim_loss.item() < 0.0: print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0") - ssim_loss == 0.0 + ssim_loss = torch.tensor([0.0]) return ssim_loss diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py index 70618483..2bca1be5 100644 --- a/TTS/tts/utils/ssim.py +++ b/TTS/tts/utils/ssim.py @@ -20,8 +20,7 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor: return x.mean(dim=0) elif reduction == "sum": return x.sum(dim=0) - else: - raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}") + raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}") def _validate_input( @@ -140,7 +139,7 @@ def ssim( kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y) _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel - ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2) + ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, k1=k1, k2=k2) ssim_val = ssim_map.mean(1) cs = cs_map.mean(1) @@ -268,7 +267,6 @@ def _ssim_per_channel( x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, - data_range: Union[float, int] = 1.0, k1: float = 0.01, k2: float = 0.03, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: @@ -278,7 +276,6 @@ def _ssim_per_channel( x: An input tensor. Shape :math:`(N, C, H, W)`. y: A target tensor. Shape :math:`(N, C, H, W)`. kernel: 2D Gaussian kernel. - data_range: Maximum value range of images (usually 1.0 or 255). k1: Algorithm parameter, K1 (small constant, see [1]). k2: Algorithm parameter, K2 (small constant, see [1]). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. @@ -321,7 +318,6 @@ def _ssim_per_channel_complex( x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, - data_range: Union[float, int] = 1.0, k1: float = 0.01, k2: float = 0.03, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: @@ -331,7 +327,6 @@ def _ssim_per_channel_complex( x: An input tensor. Shape :math:`(N, C, H, W, 2)`. y: A target tensor. Shape :math:`(N, C, H, W, 2)`. kernel: 2-D gauss kernel. - data_range: Maximum value range of images (usually 1.0 or 255). k1: Algorithm parameter, K1 (small constant, see [1]). k2: Algorithm parameter, K2 (small constant, see [1]). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. diff --git a/tests/tts_tests/test_losses.py b/tests/tts_tests/test_losses.py index e7999daa..522b7bb1 100644 --- a/tests/tts_tests/test_losses.py +++ b/tests/tts_tests/test_losses.py @@ -1,7 +1,6 @@ import unittest import torch as T -from torch.nn import functional from TTS.tts.layers.losses import BCELossMasked, L1LossMasked, MSELossMasked, SSIMLoss from TTS.tts.utils.helpers import sequence_mask @@ -208,8 +207,6 @@ class BCELossTest(unittest.TestCase): layer = BCELossMasked(pos_weight=5.0) length = T.tensor([95]) - mask = sequence_mask(length, 100) - pos_weight = T.tensor([5.0]) target = ( 1.0 - sequence_mask(length - 1, 100).float() ) # [0, 0, .... 1, 1] where the first 1 is the last mel frame @@ -236,6 +233,7 @@ class BCELossTest(unittest.TestCase): self.assertEqual(loss.item(), 0.0) # when pos_weight < 1 overweight the early stopping loss + loss_early = layer(early_x, target, length) loss_late = layer(late_x, target, length) self.assertGreater(loss_early.item(), loss_late.item())