mirror of https://github.com/coqui-ai/TTS.git

Make style

This commit is contained in:
parent a6f73a18cb
commit 2cf89b88c9
@@ -101,6 +101,7 @@ def sample_wise_min_max(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
     minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True)
     return (x - minimum) / (maximum - minimum + 1e-8)
 
+
 class SSIMLoss(torch.nn.Module):
     """SSIM loss as (1 - SSIM)
     SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity
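Note: only part of sample_wise_min_max appears in this hunk. A minimal sketch of what it computes, with the `maximum` counterpart and the tensor shapes assumed for illustration:

```python
import numpy as np
import torch

x = torch.rand(2, 10, 80)  # e.g. (batch, time, mel_channels)
mask = torch.ones(2, 10, 80, dtype=torch.bool)
mask[1, 7:] = False  # pretend the second sample has only 7 valid frames

# Assumed counterpart of the `minimum` line shown in the hunk above.
maximum = torch.amax(x.masked_fill(~mask, -np.inf), dim=(1, 2), keepdim=True)
minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True)
normed = (x - minimum) / (maximum - minimum + 1e-8)  # each sample scaled to ~[0, 1]
```

Padded positions are excluded from the per-sample extrema, so padding cannot skew the normalization.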
@@ -154,7 +155,15 @@ class AttentionEntropyLoss(nn.Module):
 
 
 class BCELossMasked(nn.Module):
-    def __init__(self, pos_weight:float=None):
+    """BCE loss with masking.
+
+    Used mainly for stopnet in autoregressive models.
+
+    Args:
+        pos_weight (float): weight for positive samples. If set < 1, penalize early stopping. Defaults to None.
+    """
+
+    def __init__(self, pos_weight: float = None):
         super().__init__()
         self.pos_weight = torch.tensor([pos_weight])
 
@@ -181,7 +190,9 @@ class BCELossMasked(nn.Module):
             # mask: (batch, max_len, 1)
             mask = sequence_mask(sequence_length=length, max_len=target.size(1))
             num_items = mask.sum()
-            loss = functional.binary_cross_entropy_with_logits(x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum")
+            loss = functional.binary_cross_entropy_with_logits(
+                x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum"
+            )
         else:
             loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
             num_items = torch.numel(x)
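Note: a hedged usage sketch of BCELossMasked as reformatted above. The forward(x, target, length) call order is taken from the tests further down in this commit; the shapes are illustrative:

```python
import torch

from TTS.tts.layers.losses import BCELossMasked
from TTS.tts.utils.helpers import sequence_mask

criterion = BCELossMasked(pos_weight=5.0)
lengths = torch.tensor([95, 95, 95, 95])
# Stopnet targets: 0 while decoding continues, 1 from the last valid frame on.
targets = 1.0 - sequence_mask(lengths - 1, 100).float()
logits = torch.randn(4, 100)  # one raw stopnet logit per decoder frame
loss = criterion(logits, targets, lengths)
```

Masking restricts the loss to the first `length` frames of each sample; with reduction="sum", the counted num_items is presumably used to normalize the summed loss.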
@@ -1,35 +1,35 @@
 # Adopted from https://github.com/photosynthesis-team/piq
 
-from typing import Optional, Tuple, Union, List
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss
 
 
-def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
+def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
     r"""Reduce input in batch dimension if needed.
     Args:
         x: Tensor with shape (N, *).
         reduction: Specifies the reduction type:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
     """
-    if reduction == 'none':
+    if reduction == "none":
         return x
-    elif reduction == 'mean':
+    elif reduction == "mean":
         return x.mean(dim=0)
-    elif reduction == 'sum':
+    elif reduction == "sum":
         return x.sum(dim=0)
     else:
         raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
 
 
 def _validate_input(
-        tensors: List[torch.Tensor],
-        dim_range: Tuple[int, int] = (0, -1),
-        data_range: Tuple[float, float] = (0., -1.),
-        # size_dim_range: Tuple[float, float] = (0., -1.),
-        size_range: Optional[Tuple[int, int]] = None,
+    tensors: List[torch.Tensor],
+    dim_range: Tuple[int, int] = (0, -1),
+    data_range: Tuple[float, float] = (0.0, -1.0),
+    # size_dim_range: Tuple[float, float] = (0., -1.),
+    size_range: Optional[Tuple[int, int]] = None,
 ) -> None:
     r"""Check that input(-s) satisfies the requirements
     Args:
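Note: the reduction in _reduce is batch-only, which is easy to miss. A small worked example:

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # shape (N, *) with N == 2

# reduction="none" -> x returned unchanged, shape (2, 2)
# reduction="mean" -> x.mean(dim=0) == tensor([2.0, 3.0])
# reduction="sum"  -> x.sum(dim=0)  == tensor([4.0, 6.0])
```

Only dim 0 is reduced; all remaining dimensions are kept.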
@@ -45,26 +45,26 @@ def _validate_input(
     x = tensors[0]
 
     for t in tensors:
-        assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
-        assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
+        assert torch.is_tensor(t), f"Expected torch.Tensor, got {type(t)}"
+        assert t.device == x.device, f"Expected tensors to be on {x.device}, got {t.device}"
 
         if size_range is None:
-            assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
+            assert t.size() == x.size(), f"Expected tensors with same size, got {t.size()} and {x.size()}"
         else:
-            assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
-                f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
+            assert (
+                t.size()[size_range[0] : size_range[1]] == x.size()[size_range[0] : size_range[1]]
+            ), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}"
 
         if dim_range[0] == dim_range[1]:
-            assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
+            assert t.dim() == dim_range[0], f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}"
         elif dim_range[0] < dim_range[1]:
-            assert dim_range[0] <= t.dim() <= dim_range[1], \
-                f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
+            assert (
+                dim_range[0] <= t.dim() <= dim_range[1]
+            ), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}"
 
         if data_range[0] < data_range[1]:
-            assert data_range[0] <= t.min(), \
-                f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
-            assert t.max() <= data_range[1], \
-                f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
+            assert data_range[0] <= t.min(), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}"
+            assert t.max() <= data_range[1], f"Expected values to be lower or equal to {data_range[1]}, got {t.max()}"
 
 
 def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor:
@@ -76,18 +76,27 @@ def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor:
         gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
     """
     coords = torch.arange(kernel_size, dtype=torch.float32)
-    coords -= (kernel_size - 1) / 2.
+    coords -= (kernel_size - 1) / 2.0
 
-    g = coords ** 2
-    g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp()
+    g = coords**2
+    g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp()
 
     g /= g.sum()
     return g.unsqueeze(0)
 
 
-def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
-         data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False,
-         downsample: bool = True, k1: float = 0.01, k2: float = 0.03) -> List[torch.Tensor]:
+def ssim(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    kernel_size: int = 11,
+    kernel_sigma: float = 1.5,
+    data_range: Union[int, float] = 1.0,
+    reduction: str = "mean",
+    full: bool = False,
+    downsample: bool = True,
+    k1: float = 0.01,
+    k2: float = 0.03,
+) -> List[torch.Tensor]:
     r"""Interface of Structural Similarity (SSIM) index.
     Inputs supposed to be in range ``[0, data_range]``.
     To match performance with skimage and tensorflow set ``'downsample' = True``.
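Note: a quick sanity check on gaussian_filter as reformatted above. This just re-runs the function body from the hunk, so it is grounded in the diff itself:

```python
import torch

kernel_size, sigma = 11, 1.5
coords = torch.arange(kernel_size, dtype=torch.float32)
coords -= (kernel_size - 1) / 2.0

g = coords**2
g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp()
g /= g.sum()
kernel = g.unsqueeze(0)

assert kernel.shape == (1, kernel_size, kernel_size)  # leading channel dim
assert torch.isclose(kernel.sum(), torch.tensor(1.0))  # normalized to sum to 1
```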
@@ -117,7 +126,7 @@ def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma:
         https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
         DOI: `10.1109/TIP.2003.819861`
     """
-    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
+    assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
     _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))
 
     x = x / float(data_range)
@@ -199,10 +208,18 @@ class SSIMLoss(_Loss):
         https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
         DOI:`10.1109/TIP.2003.819861`
     """
-    __constants__ = ['kernel_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']
+    __constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"]
 
-    def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03,
-                 downsample: bool = True, reduction: str = 'mean', data_range: Union[int, float] = 1.) -> None:
+    def __init__(
+        self,
+        kernel_size: int = 11,
+        kernel_sigma: float = 1.5,
+        k1: float = 0.01,
+        k2: float = 0.03,
+        downsample: bool = True,
+        reduction: str = "mean",
+        data_range: Union[int, float] = 1.0,
+    ) -> None:
         super().__init__()
 
         # Generic loss parameters.
@@ -213,7 +230,7 @@ class SSIMLoss(_Loss):
 
         # This check might look redundant because kernel size is checked within the ssim function anyway.
         # However, this check allows to fail fast when the loss is being initialised and training has not been started.
-        assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
+        assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
         self.kernel_sigma = kernel_sigma
         self.k1 = k1
         self.k2 = k2
@@ -232,14 +249,29 @@ class SSIMLoss(_Loss):
         complex value is returned as a tensor of size 2.
         """
 
-        score = ssim(x=x, y=y, kernel_size=self.kernel_size, kernel_sigma=self.kernel_sigma, downsample=self.downsample,
-                     data_range=self.data_range, reduction=self.reduction, full=False, k1=self.k1, k2=self.k2)
+        score = ssim(
+            x=x,
+            y=y,
+            kernel_size=self.kernel_size,
+            kernel_sigma=self.kernel_sigma,
+            downsample=self.downsample,
+            data_range=self.data_range,
+            reduction=self.reduction,
+            full=False,
+            k1=self.k1,
+            k2=self.k2,
+        )
         return torch.ones_like(score) - score
 
 
-def _ssim_per_channel(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor,
-                      data_range: Union[float, int] = 1., k1: float = 0.01,
-                      k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+def _ssim_per_channel(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    kernel: torch.Tensor,
+    data_range: Union[float, int] = 1.0,
+    k1: float = 0.01,
+    k2: float = 0.03,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     r"""Calculate Structural Similarity (SSIM) index for X and Y per channel.
 
     Args:
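Note: a minimal usage sketch for this SSIMLoss, assuming single-channel image-like inputs; the 4D shape follows the dim_range=(4, 5) check in ssim above, and values must lie in [0, data_range]:

```python
import torch

loss_fn = SSIMLoss(kernel_size=11, kernel_sigma=1.5, data_range=1.0)
x = torch.rand(4, 1, 57, 128)  # (batch, channels, height, width) in [0, 1]
y = torch.rand(4, 1, 57, 128)
loss = loss_fn(x, y)  # 1 - SSIM(x, y); equals 0 when x == y
```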
@@ -255,37 +287,44 @@ def _ssim_per_channel(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor,
         Full Value of Structural Similarity (SSIM) index.
     """
     if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2):
-        raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
-                         f'Kernel size: {kernel.size()}')
+        raise ValueError(
+            f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
+            f"Kernel size: {kernel.size()}"
+        )
 
-    c1 = k1 ** 2
-    c2 = k2 ** 2
+    c1 = k1**2
+    c2 = k2**2
     n_channels = x.size(1)
     mu_x = F.conv2d(x, weight=kernel, stride=1, padding=0, groups=n_channels)
     mu_y = F.conv2d(y, weight=kernel, stride=1, padding=0, groups=n_channels)
 
-    mu_xx = mu_x ** 2
-    mu_yy = mu_y ** 2
+    mu_xx = mu_x**2
+    mu_yy = mu_y**2
     mu_xy = mu_x * mu_y
 
-    sigma_xx = F.conv2d(x ** 2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx
-    sigma_yy = F.conv2d(y ** 2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy
+    sigma_xx = F.conv2d(x**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx
+    sigma_yy = F.conv2d(y**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy
     sigma_xy = F.conv2d(x * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xy
 
     # Contrast sensitivity (CS) with alpha = beta = gamma = 1.
-    cs = (2. * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
+    cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
 
     # Structural similarity (SSIM)
-    ss = (2. * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs
+    ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs
 
     ssim_val = ss.mean(dim=(-1, -2))
     cs = cs.mean(dim=(-1, -2))
     return ssim_val, cs
 
 
-def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor,
-                              data_range: Union[float, int] = 1., k1: float = 0.01,
-                              k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+def _ssim_per_channel_complex(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    kernel: torch.Tensor,
+    data_range: Union[float, int] = 1.0,
+    k1: float = 0.01,
+    k2: float = 0.03,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel.
 
     Args:
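Note: for reference, _ssim_per_channel implements the standard SSIM index (Wang et al., 2004) with alpha = beta = gamma = 1, so it factors into exactly the two terms computed above as `ss` and `cs`. The inputs are divided by data_range beforehand, so the dynamic range L is 1 and the stabilizers reduce to c1 = k1**2 and c2 = k2**2:

```latex
\mathrm{SSIM}(x, y) =
  \underbrace{\frac{2\mu_x\mu_y + c_1}{\mu_x^2 + \mu_y^2 + c_1}}_{\text{luminance term}}
  \cdot
  \underbrace{\frac{2\sigma_{xy} + c_2}{\sigma_x^2 + \sigma_y^2 + c_2}}_{\text{contrast/structure term (cs)}},
\qquad c_1 = (k_1 L)^2,\quad c_2 = (k_2 L)^2,\quad L = 1
```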
@@ -302,11 +341,13 @@ def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor,
     """
     n_channels = x.size(1)
     if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2):
-        raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
-                         f'Kernel size: {kernel.size()}')
+        raise ValueError(
+            f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
+            f"Kernel size: {kernel.size()}"
+        )
 
-    c1 = k1 ** 2
-    c2 = k2 ** 2
+    c1 = k1**2
+    c2 = k2**2
 
     x_real = x[..., 0]
     x_imag = x[..., 1]
@@ -344,4 +385,4 @@ def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor,
     ssim_val = ssim_map.mean(dim=(-2, -3))
     cs = cs_map.mean(dim=(-2, -3))
 
-    return ssim_val, cs
\ No newline at end of file
+    return ssim_val, cs
@@ -1,9 +1,10 @@
 import unittest
 
 import torch as T
 from torch.nn import functional
 
-from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked, BCELossMasked
+from TTS.tts.layers.losses import BCELossMasked, L1LossMasked, MSELossMasked, SSIMLoss
 from TTS.tts.utils.helpers import sequence_mask
 
+
 class L1LossMaskedTests(unittest.TestCase):
@@ -134,7 +135,6 @@ class MSELossMaskedTests(unittest.TestCase):
         assert output.item() == 0, "0 vs {}".format(output.item())
 
 
-
 class SSIMLossTests(unittest.TestCase):
     def test_in_out(self):  # pylint: disable=no-self-use
         # test input == target
@@ -150,7 +150,7 @@ class SSIMLossTests(unittest.TestCase):
         dummy_input = dummy_input.reshape(4, 57, 128).float()
         dummy_target = T.arange(-4 * 57 * 128, 0)
         dummy_target = dummy_target.reshape(4, 57, 128).float()
-        dummy_target = (-dummy_target)
+        dummy_target = -dummy_target
 
         dummy_length = (T.ones(4) * 58).long()
         output = layer(dummy_input, dummy_target, dummy_length)
@@ -210,11 +210,13 @@ class BCELossTest(unittest.TestCase):
         length = T.tensor([95])
         mask = sequence_mask(length, 100)
         pos_weight = T.tensor([5.0])
-        target = 1. - sequence_mask(length - 1, 100).float()  # [0, 0, .... 1, 1] where the first 1 is the last mel frame
+        target = (
+            1.0 - sequence_mask(length - 1, 100).float()
+        )  # [0, 0, .... 1, 1] where the first 1 is the last mel frame
         true_x = target * 200 - 100  # creates logits of [-100, -100, ... 100, 100] corresponding to target
-        zero_x = T.zeros(target.shape) - 100.  # simulate logits if it never stops decoding
-        early_x = -200. * sequence_mask(length - 3, 100).float() + 100.  # simulate logits on early stopping
-        late_x = -200. * sequence_mask(length + 1, 100).float() + 100.  # simulate logits on late stopping
+        zero_x = T.zeros(target.shape) - 100.0  # simulate logits if it never stops decoding
+        early_x = -200.0 * sequence_mask(length - 3, 100).float() + 100.0  # simulate logits on early stopping
+        late_x = -200.0 * sequence_mask(length + 1, 100).float() + 100.0  # simulate logits on late stopping
 
         loss = layer(true_x, target, length)
         self.assertEqual(loss.item(), 0.0)
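Note: a short sketch of why this test construction works. With logits of magnitude 100, the sigmoid saturates to 0 or 1, so true_x reproduces target exactly and yields (near) zero loss, while early/late/never-stopping logits disagree with target on a few frames and are penalized:

```python
import torch
from torch.nn import functional

target = torch.tensor([0.0, 0.0, 1.0, 1.0])
true_x = target * 200 - 100  # [-100, -100, 100, 100]: matches target exactly
late_x = torch.tensor([-100.0, -100.0, -100.0, 100.0])  # misses one stop frame

loss_true = functional.binary_cross_entropy_with_logits(true_x, target, reduction="sum")
loss_late = functional.binary_cross_entropy_with_logits(late_x, target, reduction="sum")
assert loss_true.item() < 1e-6 < loss_late.item()  # ~0 vs. ~100
```

pos_weight scales the loss on positive (stop) frames, so it tunes how strongly missed or premature stop decisions are punished.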