From 2cf89b88c986e3bfcd0e0fd976571e7cce6f1de5 Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Tue, 12 Jul 2022 14:12:57 +0200
Subject: [PATCH] Make style

---
 TTS/tts/layers/losses.py       |  15 +++-
 TTS/tts/utils/ssim.py          | 153 +++++++++++++++++++++------------
 tests/tts_tests/test_losses.py |  16 ++--
 3 files changed, 119 insertions(+), 65 deletions(-)

diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 50c61d67..e43dd6b1 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -101,6 +101,7 @@ def sample_wise_min_max(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
     minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True)
     return (x - minimum) / (maximum - minimum + 1e-8)
 
+
 class SSIMLoss(torch.nn.Module):
     """SSIM loss as (1 - SSIM)
     SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity
@@ -154,7 +155,15 @@ class AttentionEntropyLoss(nn.Module):
 
 
 class BCELossMasked(nn.Module):
-    def __init__(self, pos_weight:float=None):
+    """BCE loss with masking.
+
+    Used mainly for stopnet in autoregressive models.
+
+    Args:
+        pos_weight (float): weight for positive samples. If set < 1, penalize early stopping. Defaults to None.
+    """
+
+    def __init__(self, pos_weight: float = None):
         super().__init__()
         self.pos_weight = torch.tensor([pos_weight])
 
@@ -181,7 +190,9 @@ class BCELossMasked(nn.Module):
             # mask: (batch, max_len, 1)
             mask = sequence_mask(sequence_length=length, max_len=target.size(1))
             num_items = mask.sum()
-            loss = functional.binary_cross_entropy_with_logits(x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum")
+            loss = functional.binary_cross_entropy_with_logits(
+                x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum"
+            )
         else:
             loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
             num_items = torch.numel(x)
diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py
index c83504cb..70618483 100644
--- a/TTS/tts/utils/ssim.py
+++ b/TTS/tts/utils/ssim.py
@@ -1,35 +1,35 @@
 # Adopted from https://github.com/photosynthesis-team/piq
 
-from typing import Optional, Tuple, Union, List
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss
 
 
-def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
+def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
     r"""Reduce input in batch dimension if needed.
     Args:
         x: Tensor with shape (N, *).
         reduction: Specifies the reduction type:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
     """
-    if reduction == 'none':
+    if reduction == "none":
         return x
-    elif reduction == 'mean':
+    elif reduction == "mean":
         return x.mean(dim=0)
-    elif reduction == 'sum':
+    elif reduction == "sum":
         return x.sum(dim=0)
     else:
         raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
 
 
 def _validate_input(
-        tensors: List[torch.Tensor],
-        dim_range: Tuple[int, int] = (0, -1),
-        data_range: Tuple[float, float] = (0., -1.),
-        # size_dim_range: Tuple[float, float] = (0., -1.),
-        size_range: Optional[Tuple[int, int]] = None,
+    tensors: List[torch.Tensor],
+    dim_range: Tuple[int, int] = (0, -1),
+    data_range: Tuple[float, float] = (0.0, -1.0),
+    # size_dim_range: Tuple[float, float] = (0., -1.),
+    size_range: Optional[Tuple[int, int]] = None,
 ) -> None:
     r"""Check that input(-s) satisfies the requirements
     Args:
@@ -45,26 +45,26 @@ def _validate_input(
     x = tensors[0]
 
     for t in tensors:
-        assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
-        assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
+        assert torch.is_tensor(t), f"Expected torch.Tensor, got {type(t)}"
+        assert t.device == x.device, f"Expected tensors to be on {x.device}, got {t.device}"
 
         if size_range is None:
-            assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
+            assert t.size() == x.size(), f"Expected tensors with same size, got {t.size()} and {x.size()}"
         else:
-            assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
-                f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
+            assert (
+                t.size()[size_range[0] : size_range[1]] == x.size()[size_range[0] : size_range[1]]
+            ), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}"
 
         if dim_range[0] == dim_range[1]:
-            assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
+            assert t.dim() == dim_range[0], f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}"
         elif dim_range[0] < dim_range[1]:
-            assert dim_range[0] <= t.dim() <= dim_range[1], \
-                f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
+            assert (
+                dim_range[0] <= t.dim() <= dim_range[1]
+            ), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}"
 
         if data_range[0] < data_range[1]:
-            assert data_range[0] <= t.min(), \
-                f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
-            assert t.max() <= data_range[1], \
-                f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
+            assert data_range[0] <= t.min(), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}"
+            assert t.max() <= data_range[1], f"Expected values to be lower or equal to {data_range[1]}, got {t.max()}"
 
 
 def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor:
@@ -76,18 +76,27 @@ def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor:
         gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
     """
     coords = torch.arange(kernel_size, dtype=torch.float32)
-    coords -= (kernel_size - 1) / 2.
+    coords -= (kernel_size - 1) / 2.0
 
-    g = coords ** 2
-    g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp()
+    g = coords**2
+    g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp()
 
     g /= g.sum()
     return g.unsqueeze(0)
 
 
-def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
-         data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False,
-         downsample: bool = True, k1: float = 0.01, k2: float = 0.03) -> List[torch.Tensor]:
+def ssim(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    kernel_size: int = 11,
+    kernel_sigma: float = 1.5,
+    data_range: Union[int, float] = 1.0,
+    reduction: str = "mean",
+    full: bool = False,
+    downsample: bool = True,
+    k1: float = 0.01,
+    k2: float = 0.03,
+) -> List[torch.Tensor]:
     r"""Interface of Structural Similarity (SSIM) index.
     Inputs supposed to be in range ``[0, data_range]``.
     To match performance with skimage and tensorflow set ``'downsample' = True``.
@@ -117,7 +126,7 @@ def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma:
         https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
         DOI: `10.1109/TIP.2003.819861`
     """
-    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
+    assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
     _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))
 
     x = x / float(data_range)
@@ -199,10 +208,18 @@ class SSIMLoss(_Loss):
         https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
         DOI:`10.1109/TIP.2003.819861`
     """
-    __constants__ = ['kernel_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']
+    __constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"]
 
-    def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03,
-                 downsample: bool = True, reduction: str = 'mean', data_range: Union[int, float] = 1.) -> None:
+    def __init__(
+        self,
+        kernel_size: int = 11,
+        kernel_sigma: float = 1.5,
+        k1: float = 0.01,
+        k2: float = 0.03,
+        downsample: bool = True,
+        reduction: str = "mean",
+        data_range: Union[int, float] = 1.0,
+    ) -> None:
         super().__init__()
 
         # Generic loss parameters.
@@ -213,7 +230,7 @@ class SSIMLoss(_Loss):
 
         # This check might look redundant because kernel size is checked within the ssim function anyway.
         # However, this check allows to fail fast when the loss is being initialised and training has not been started.
-        assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
+        assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
         self.kernel_sigma = kernel_sigma
         self.k1 = k1
         self.k2 = k2
@@ -232,14 +249,29 @@ class SSIMLoss(_Loss):
             complex value is returned as a tensor of size 2.
""" - score = ssim(x=x, y=y, kernel_size=self.kernel_size, kernel_sigma=self.kernel_sigma, downsample=self.downsample, - data_range=self.data_range, reduction=self.reduction, full=False, k1=self.k1, k2=self.k2) + score = ssim( + x=x, + y=y, + kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, + downsample=self.downsample, + data_range=self.data_range, + reduction=self.reduction, + full=False, + k1=self.k1, + k2=self.k2, + ) return torch.ones_like(score) - score -def _ssim_per_channel(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, - data_range: Union[float, int] = 1., k1: float = 0.01, - k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: +def _ssim_per_channel( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor, + data_range: Union[float, int] = 1.0, + k1: float = 0.01, + k2: float = 0.03, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Calculate Structural Similarity (SSIM) index for X and Y per channel. Args: @@ -255,37 +287,44 @@ def _ssim_per_channel(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, Full Value of Structural Similarity (SSIM) index. """ if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2): - raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. ' - f'Kernel size: {kernel.size()}') + raise ValueError( + f"Kernel size can't be greater than actual input size. Input size: {x.size()}. " + f"Kernel size: {kernel.size()}" + ) - c1 = k1 ** 2 - c2 = k2 ** 2 + c1 = k1**2 + c2 = k2**2 n_channels = x.size(1) mu_x = F.conv2d(x, weight=kernel, stride=1, padding=0, groups=n_channels) mu_y = F.conv2d(y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx = mu_x ** 2 - mu_yy = mu_y ** 2 + mu_xx = mu_x**2 + mu_yy = mu_y**2 mu_xy = mu_x * mu_y - sigma_xx = F.conv2d(x ** 2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx - sigma_yy = F.conv2d(y ** 2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy + sigma_xx = F.conv2d(x**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx + sigma_yy = F.conv2d(y**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy sigma_xy = F.conv2d(x * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xy # Contrast sensitivity (CS) with alpha = beta = gamma = 1. - cs = (2. * sigma_xy + c2) / (sigma_xx + sigma_yy + c2) + cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2) # Structural similarity (SSIM) - ss = (2. * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs + ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs ssim_val = ss.mean(dim=(-1, -2)) cs = cs.mean(dim=(-1, -2)) return ssim_val, cs -def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, - data_range: Union[float, int] = 1., k1: float = 0.01, - k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: +def _ssim_per_channel_complex( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor, + data_range: Union[float, int] = 1.0, + k1: float = 0.01, + k2: float = 0.03, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel. Args: @@ -302,11 +341,13 @@ def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Te """ n_channels = x.size(1) if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2): - raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. 
-                         f'Kernel size: {kernel.size()}')
+        raise ValueError(
+            f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
+            f"Kernel size: {kernel.size()}"
+        )
 
-    c1 = k1 ** 2
-    c2 = k2 ** 2
+    c1 = k1**2
+    c2 = k2**2
 
     x_real = x[..., 0]
     x_imag = x[..., 1]
@@ -344,4 +385,4 @@ def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Te
     ssim_val = ssim_map.mean(dim=(-2, -3))
     cs = cs_map.mean(dim=(-2, -3))
 
-    return ssim_val, cs
\ No newline at end of file
+    return ssim_val, cs
diff --git a/tests/tts_tests/test_losses.py b/tests/tts_tests/test_losses.py
index e4652408..e7999daa 100644
--- a/tests/tts_tests/test_losses.py
+++ b/tests/tts_tests/test_losses.py
@@ -1,9 +1,10 @@
 import unittest
+
 import torch as T
 from torch.nn import functional
 
+from TTS.tts.layers.losses import BCELossMasked, L1LossMasked, MSELossMasked, SSIMLoss
 from TTS.tts.utils.helpers import sequence_mask
-from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked, BCELossMasked
 
 
 class L1LossMaskedTests(unittest.TestCase):
@@ -134,7 +135,6 @@ class MSELossMaskedTests(unittest.TestCase):
         assert output.item() == 0, "0 vs {}".format(output.item())
 
 
-
 class SSIMLossTests(unittest.TestCase):
     def test_in_out(self):  # pylint: disable=no-self-use
         # test input == target
@@ -150,7 +150,7 @@ class SSIMLossTests(unittest.TestCase):
         dummy_input = dummy_input.reshape(4, 57, 128).float()
         dummy_target = T.arange(-4 * 57 * 128, 0)
         dummy_target = dummy_target.reshape(4, 57, 128).float()
-        dummy_target = (-dummy_target)
+        dummy_target = -dummy_target
 
         dummy_length = (T.ones(4) * 58).long()
         output = layer(dummy_input, dummy_target, dummy_length)
@@ -210,11 +210,13 @@ class BCELossTest(unittest.TestCase):
         length = T.tensor([95])
         mask = sequence_mask(length, 100)
         pos_weight = T.tensor([5.0])
-        target = 1. - sequence_mask(length - 1, 100).float()  # [0, 0, .... 1, 1] where the first 1 is the last mel frame
+        target = (
+            1.0 - sequence_mask(length - 1, 100).float()
+        )  # [0, 0, .... 1, 1] where the first 1 is the last mel frame
         true_x = target * 200 - 100  # creates logits of [-100, -100, ... 100, 100] corresponding to target
-        zero_x = T.zeros(target.shape) - 100.  # simulate logits if it never stops decoding
-        early_x = -200. * sequence_mask(length - 3, 100).float() + 100.  # simulate logits on early stopping
-        late_x = -200. * sequence_mask(length + 1, 100).float() + 100.  # simulate logits on late stopping
+        zero_x = T.zeros(target.shape) - 100.0  # simulate logits if it never stops decoding
+        early_x = -200.0 * sequence_mask(length - 3, 100).float() + 100.0  # simulate logits on early stopping
+        late_x = -200.0 * sequence_mask(length + 1, 100).float() + 100.0  # simulate logits on late stopping
 
         loss = layer(true_x, target, length)
         self.assertEqual(loss.item(), 0.0)
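
For reference, a minimal usage sketch of the two losses touched by this patch, mirroring the updated tests in tests/tts_tests/test_losses.py. The shapes and values below are illustrative assumptions, not repository code:

    import torch

    from TTS.tts.layers.losses import BCELossMasked, SSIMLoss
    from TTS.tts.utils.helpers import sequence_mask

    # SSIMLoss treats (batch, time, mel) spectrograms as images and uses the
    # per-sample lengths to mask padded frames; identical inputs give zero loss.
    ssim_loss = SSIMLoss()
    y = torch.rand(4, 57, 128)
    lengths = (torch.ones(4) * 50).long()
    print(ssim_loss(y, y, lengths))  # -> tensor(0.)

    # BCELossMasked scores stopnet logits only on unmasked frames; the summed
    # BCE is normalized by the number of valid items.
    bce = BCELossMasked(pos_weight=5.0)
    length = torch.tensor([95])
    target = 1.0 - sequence_mask(length - 1, 100).float()  # 1.0 on the final frames
    logits = target * 200.0 - 100.0  # confident logits that match the target
    print(bce(logits, target, length))  # -> tensor(0.)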