Make style

Eren Gölge 2022-07-12 14:12:57 +02:00
parent a6f73a18cb
commit 2cf89b88c9
3 changed files with 119 additions and 65 deletions

View File

@@ -101,6 +101,7 @@ def sample_wise_min_max(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True)
return (x - minimum) / (maximum - minimum + 1e-8)
class SSIMLoss(torch.nn.Module):
"""SSIM loss as (1 - SSIM)
SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity
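The normalisation above is what makes SSIM applicable to spectrograms: each sample is squashed into [0, 1] before comparison. A minimal sketch of the same arithmetic (the amax line is reconstructed by symmetry with the amin line shown in the hunk; shapes are hypothetical):

import torch

# Hypothetical batch of spectrograms: (batch, time, mel_channels).
x = torch.rand(4, 57, 128) * 10 - 5
mask = torch.ones(4, 57, 128, dtype=torch.bool)

# Same computation as sample_wise_min_max above, with torch's inf in place of np.inf.
maximum = torch.amax(x.masked_fill(~mask, float("-inf")), dim=(1, 2), keepdim=True)
minimum = torch.amin(x.masked_fill(~mask, float("inf")), dim=(1, 2), keepdim=True)
x_norm = (x - minimum) / (maximum - minimum + 1e-8)

# Every sample now lies in [0, 1], matching SSIM's data_range=1.0 default.
assert x_norm.amin() >= 0.0 and x_norm.amax() <= 1.0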
@@ -154,7 +155,15 @@ class AttentionEntropyLoss(nn.Module):
class BCELossMasked(nn.Module):
def __init__(self, pos_weight:float=None):
"""BCE loss with masking.
Used mainly for stopnet in autoregressive models.
Args:
pos_weight (float): weight for positive samples. If set < 1, penalize early stopping. Defaults to None.
"""
def __init__(self, pos_weight: float = None):
super().__init__()
self.pos_weight = torch.tensor([pos_weight])
@@ -181,7 +190,9 @@ class BCELossMasked(nn.Module):
# mask: (batch, max_len, 1)
mask = sequence_mask(sequence_length=length, max_len=target.size(1))
num_items = mask.sum()
loss = functional.binary_cross_entropy_with_logits(x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum")
loss = functional.binary_cross_entropy_with_logits(
x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum"
)
else:
loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
num_items = torch.numel(x)
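To make the masked branch above concrete, here is a hedged sketch with a minimal stand-in for TTS.tts.utils.helpers.sequence_mask (the stand-in is an assumption; the real helper lives in the TTS package). Padded frames are dropped before the BCE is computed, and the summed loss is normalised by the number of unmasked items, as num_items suggests:

import torch
from torch.nn import functional


def sequence_mask(sequence_length, max_len):
    # Minimal stand-in for TTS.tts.utils.helpers.sequence_mask (assumption).
    seq_range = torch.arange(max_len, device=sequence_length.device)
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)


length = torch.tensor([3])
target = torch.tensor([[0.0, 0.0, 1.0, 0.0, 0.0]])  # stop flag on the last valid frame
x = torch.tensor([[-10.0, -10.0, 10.0, 0.0, 0.0]])  # logits; frames 4-5 are padding
mask = sequence_mask(length, target.size(1))

loss = functional.binary_cross_entropy_with_logits(
    x.masked_select(mask), target.masked_select(mask),
    pos_weight=torch.tensor([5.0]), reduction="sum",
) / mask.sum()
print(loss)  # ~0: the correct logits cost nothing, the padding never enters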

View File

@@ -1,35 +1,35 @@
# Adopted from https://github.com/photosynthesis-team/piq
from typing import Optional, Tuple, Union, List
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
r"""Reduce input in batch dimension if needed.
Args:
x: Tensor with shape (N, *).
reduction: Specifies the reduction type:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
"""
if reduction == 'none':
if reduction == "none":
return x
elif reduction == 'mean':
elif reduction == "mean":
return x.mean(dim=0)
elif reduction == 'sum':
elif reduction == "sum":
return x.sum(dim=0)
else:
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
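_reduce only ever collapses the batch dimension; the modes map onto plain tensor ops, as a quick illustration:

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # shape (N, *) with N=2

print(x.mean(dim=0))  # what "mean" returns: tensor([2., 3.])
print(x.sum(dim=0))   # what "sum" returns:  tensor([4., 6.])
# "none" returns x unchanged; any other string raises the ValueError above.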
def _validate_input(
tensors: List[torch.Tensor],
dim_range: Tuple[int, int] = (0, -1),
data_range: Tuple[float, float] = (0., -1.),
# size_dim_range: Tuple[float, float] = (0., -1.),
size_range: Optional[Tuple[int, int]] = None,
tensors: List[torch.Tensor],
dim_range: Tuple[int, int] = (0, -1),
data_range: Tuple[float, float] = (0.0, -1.0),
# size_dim_range: Tuple[float, float] = (0., -1.),
size_range: Optional[Tuple[int, int]] = None,
) -> None:
r"""Check that input(-s) satisfies the requirements
Args:
@@ -45,26 +45,26 @@ def _validate_input(
x = tensors[0]
for t in tensors:
assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
assert torch.is_tensor(t), f"Expected torch.Tensor, got {type(t)}"
assert t.device == x.device, f"Expected tensors to be on {x.device}, got {t.device}"
if size_range is None:
assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
assert t.size() == x.size(), f"Expected tensors with same size, got {t.size()} and {x.size()}"
else:
assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
assert (
t.size()[size_range[0] : size_range[1]] == x.size()[size_range[0] : size_range[1]]
), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}"
if dim_range[0] == dim_range[1]:
assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
assert t.dim() == dim_range[0], f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}"
elif dim_range[0] < dim_range[1]:
assert dim_range[0] <= t.dim() <= dim_range[1], \
f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
assert (
dim_range[0] <= t.dim() <= dim_range[1]
), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}"
if data_range[0] < data_range[1]:
assert data_range[0] <= t.min(), \
f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
assert t.max() <= data_range[1], \
f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
assert data_range[0] <= t.min(), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}"
assert t.max() <= data_range[1], f"Expected values to be lower or equal to {data_range[1]}, got {t.max()}"
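A hedged usage note: the checks pass silently for well-formed inputs and fail fast otherwise, which is why ssim below calls this before any arithmetic:

import torch

x = torch.rand(2, 3, 32, 32)  # 4D, same size and device, values in [0, 1]
y = torch.rand(2, 3, 32, 32)
# _validate_input([x, y], dim_range=(4, 5), data_range=(0, 1)) would pass here;
# a 3D tensor, a device mismatch, or values outside [0, 1] would trip the
# corresponding assert with the messages shown above.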
def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor:
@@ -76,18 +76,27 @@ def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor:
gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
"""
coords = torch.arange(kernel_size, dtype=torch.float32)
coords -= (kernel_size - 1) / 2.
coords -= (kernel_size - 1) / 2.0
g = coords ** 2
g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp()
g = coords**2
g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp()
g /= g.sum()
return g.unsqueeze(0)
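A quick numeric check of the kernel, inlining the arithmetic above with the SSIM defaults (kernel_size=11, sigma=1.5):

import torch

kernel_size, sigma = 11, 1.5
coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2.0
g = coords**2
g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp()
g /= g.sum()

print(g.shape)                    # torch.Size([11, 11]); the function adds a leading dim
print(round(float(g.sum()), 6))   # 1.0 — the window is a proper weighted average
print(bool(g[5, 5] == g.max()))   # True — the Gaussian peaks at the centre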
def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False,
downsample: bool = True, k1: float = 0.01, k2: float = 0.03) -> List[torch.Tensor]:
def ssim(
x: torch.Tensor,
y: torch.Tensor,
kernel_size: int = 11,
kernel_sigma: float = 1.5,
data_range: Union[int, float] = 1.0,
reduction: str = "mean",
full: bool = False,
downsample: bool = True,
k1: float = 0.01,
k2: float = 0.03,
) -> List[torch.Tensor]:
r"""Interface of Structural Similarity (SSIM) index.
Inputs supposed to be in range ``[0, data_range]``.
To match performance with skimage and tensorflow set ``'downsample' = True``.
@@ -117,7 +126,7 @@ def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma:
https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
DOI: `10.1109/TIP.2003.819861`
"""
assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
_validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))
x = x / float(data_range)
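A minimal sanity check of the functional interface, assuming the ssim defined above is in scope: identical inputs score 1 regardless of the kernel settings.

import torch

x = torch.rand(2, 1, 64, 64)  # (batch, channels, H, W), values in [0, 1]
score = ssim(x, x, data_range=1.0)  # assumes ssim from this module
print(score)  # tensor(1.) up to floating point error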
@@ -199,10 +208,18 @@ class SSIMLoss(_Loss):
https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
DOI:`10.1109/TIP.2003.819861`
"""
__constants__ = ['kernel_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']
__constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"]
def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03,
downsample: bool = True, reduction: str = 'mean', data_range: Union[int, float] = 1.) -> None:
def __init__(
self,
kernel_size: int = 11,
kernel_sigma: float = 1.5,
k1: float = 0.01,
k2: float = 0.03,
downsample: bool = True,
reduction: str = "mean",
data_range: Union[int, float] = 1.0,
) -> None:
super().__init__()
# Generic loss parameters.
@@ -213,7 +230,7 @@ class SSIMLoss(_Loss):
# This check might look redundant because kernel size is checked within the ssim function anyway.
# However, this check allows to fail fast when the loss is being initialised and training has not been started.
assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
self.kernel_sigma = kernel_sigma
self.k1 = k1
self.k2 = k2
@@ -232,14 +249,29 @@ class SSIMLoss(_Loss):
complex value is returned as a tensor of size 2.
"""
score = ssim(x=x, y=y, kernel_size=self.kernel_size, kernel_sigma=self.kernel_sigma, downsample=self.downsample,
data_range=self.data_range, reduction=self.reduction, full=False, k1=self.k1, k2=self.k2)
score = ssim(
x=x,
y=y,
kernel_size=self.kernel_size,
kernel_sigma=self.kernel_sigma,
downsample=self.downsample,
data_range=self.data_range,
reduction=self.reduction,
full=False,
k1=self.k1,
k2=self.k2,
)
return torch.ones_like(score) - score
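A usage sketch for the module form, under the same assumption that the SSIMLoss above is in scope: because forward returns 1 - SSIM, a perfect reconstruction costs 0 and any mismatch raises the loss.

import torch

criterion = SSIMLoss(kernel_size=11, kernel_sigma=1.5)  # the class defined above
x = torch.rand(2, 1, 64, 64)
print(criterion(x, x))          # ~0.0 for identical inputs
print(criterion(x, 1 - x) > 0)  # True for mismatched inputs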
def _ssim_per_channel(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor,
data_range: Union[float, int] = 1., k1: float = 0.01,
k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
def _ssim_per_channel(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor,
data_range: Union[float, int] = 1.0,
k1: float = 0.01,
k2: float = 0.03,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
r"""Calculate Structural Similarity (SSIM) index for X and Y per channel.
Args:
@@ -255,37 +287,44 @@ def _ssim_per_channel(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor,
Full Value of Structural Similarity (SSIM) index.
"""
if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2):
raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
f'Kernel size: {kernel.size()}')
raise ValueError(
f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
f"Kernel size: {kernel.size()}"
)
c1 = k1 ** 2
c2 = k2 ** 2
c1 = k1**2
c2 = k2**2
n_channels = x.size(1)
mu_x = F.conv2d(x, weight=kernel, stride=1, padding=0, groups=n_channels)
mu_y = F.conv2d(y, weight=kernel, stride=1, padding=0, groups=n_channels)
mu_xx = mu_x ** 2
mu_yy = mu_y ** 2
mu_xx = mu_x**2
mu_yy = mu_y**2
mu_xy = mu_x * mu_y
sigma_xx = F.conv2d(x ** 2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx
sigma_yy = F.conv2d(y ** 2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy
sigma_xx = F.conv2d(x**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx
sigma_yy = F.conv2d(y**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy
sigma_xy = F.conv2d(x * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xy
# Contrast sensitivity (CS) with alpha = beta = gamma = 1.
cs = (2. * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
# Structural similarity (SSIM)
ss = (2. * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs
ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs
ssim_val = ss.mean(dim=(-1, -2))
cs = cs.mean(dim=(-1, -2))
return ssim_val, cs
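A short self-check of the algebra above: when x == y, mu_xy equals mu_xx == mu_yy and sigma_xy equals both variances, so the CS term and the SSIM map collapse to 1. A sketch with a uniform window standing in for the Gaussian kernel:

import torch
import torch.nn.functional as F

kernel = torch.full((1, 1, 3, 3), 1.0 / 9)  # uniform stand-in for the Gaussian
x = torch.rand(1, 1, 8, 8)

mu = F.conv2d(x, kernel)
sigma = F.conv2d(x**2, kernel) - mu**2
c1, c2 = 0.01**2, 0.03**2

cs = (2.0 * sigma + c2) / (sigma + sigma + c2)          # == 1 when x == y
ss = (2.0 * mu * mu + c1) / (mu**2 + mu**2 + c1) * cs   # == 1 when x == y
print(bool(torch.allclose(ss, torch.ones_like(ss))))    # True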
def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor,
data_range: Union[float, int] = 1., k1: float = 0.01,
k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
def _ssim_per_channel_complex(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor,
data_range: Union[float, int] = 1.0,
k1: float = 0.01,
k2: float = 0.03,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel.
Args:
@@ -302,11 +341,13 @@ def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Te
"""
n_channels = x.size(1)
if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2):
raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
f'Kernel size: {kernel.size()}')
raise ValueError(
f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
f"Kernel size: {kernel.size()}"
)
c1 = k1 ** 2
c2 = k2 ** 2
c1 = k1**2
c2 = k2**2
x_real = x[..., 0]
x_imag = x[..., 1]
@@ -344,4 +385,4 @@ def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Te
ssim_val = ssim_map.mean(dim=(-2, -3))
cs = cs_map.mean(dim=(-2, -3))
return ssim_val, cs
return ssim_val, cs
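One note on the complex variant: it expects the trailing dimension to carry (real, imag) pairs, i.e. inputs of shape (batch, channels, H, W, 2), matching the unpacking at the top of the function:

import torch

x = torch.randn(2, 1, 32, 32, 2)       # hypothetical complex spectrogram batch
x_real, x_imag = x[..., 0], x[..., 1]  # the same unpacking as in the function body
print(x_real.shape, x_imag.shape)      # torch.Size([2, 1, 32, 32]) twice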

View File

@@ -1,9 +1,10 @@
import unittest
import torch as T
from torch.nn import functional
from TTS.tts.layers.losses import BCELossMasked, L1LossMasked, MSELossMasked, SSIMLoss
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked, BCELossMasked
class L1LossMaskedTests(unittest.TestCase):
@@ -134,7 +135,6 @@ class MSELossMaskedTests(unittest.TestCase):
assert output.item() == 0, "0 vs {}".format(output.item())
class SSIMLossTests(unittest.TestCase):
def test_in_out(self): # pylint: disable=no-self-use
# test input == target
@@ -150,7 +150,7 @@ class SSIMLossTests(unittest.TestCase):
dummy_input = dummy_input.reshape(4, 57, 128).float()
dummy_target = T.arange(-4 * 57 * 128, 0)
dummy_target = dummy_target.reshape(4, 57, 128).float()
dummy_target = (-dummy_target)
dummy_target = -dummy_target
dummy_length = (T.ones(4) * 58).long()
output = layer(dummy_input, dummy_target, dummy_length)
@@ -210,11 +210,13 @@ class BCELossTest(unittest.TestCase):
length = T.tensor([95])
mask = sequence_mask(length, 100)
pos_weight = T.tensor([5.0])
target = 1. - sequence_mask(length - 1, 100).float() # [0, 0, .... 1, 1] where the first 1 is the last mel frame
target = (
1.0 - sequence_mask(length - 1, 100).float()
) # [0, 0, .... 1, 1] where the first 1 is the last mel frame
true_x = target * 200 - 100 # creates logits of [-100, -100, ... 100, 100] corresponding to target
zero_x = T.zeros(target.shape) - 100. # simulate logits if it never stops decoding
early_x = -200. * sequence_mask(length - 3, 100).float() + 100. # simulate logits on early stopping
late_x = -200. * sequence_mask(length + 1, 100).float() + 100. # simulate logits on late stopping
zero_x = T.zeros(target.shape) - 100.0 # simulate logits if it never stops decoding
early_x = -200.0 * sequence_mask(length - 3, 100).float() + 100.0 # simulate logits on early stopping
late_x = -200.0 * sequence_mask(length + 1, 100).float() + 100.0 # simulate logits on late stopping
loss = layer(true_x, target, length)
self.assertEqual(loss.item(), 0.0)
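For intuition about the constructed logits in this test: ±100 saturates the sigmoid, so correct frames cost ~0 and each wrong frame contributes ~100 to the summed BCE, multiplied by pos_weight when the missed target is positive. A tiny check:

import torch
from torch.nn import functional

target = torch.tensor([0.0, 1.0])
perfect = torch.tensor([-100.0, 100.0])       # correct, saturated logits
misses_stop = torch.tensor([-100.0, -100.0])  # never predicts the stop frame

print(functional.binary_cross_entropy_with_logits(perfect, target, reduction="sum"))      # ~0
print(functional.binary_cross_entropy_with_logits(misses_stop, target, reduction="sum"))  # ~100
print(functional.binary_cross_entropy_with_logits(
    misses_stop, target, pos_weight=torch.tensor([5.0]), reduction="sum"
))  # ~500 — missing the stop frame is five times as expensive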