mirror of https://github.com/coqui-ai/TTS.git
Fix SSIM loss
parent f1e35596e8
commit c17ff17a18
@@ -7,7 +7,7 @@ from torch import nn
 from torch.nn import functional
 
 from TTS.tts.utils.helpers import sequence_mask
-from TTS.tts.utils.ssim import ssim
+from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
 from TTS.utils.audio import TorchSTFT
 
 
@@ -91,30 +91,54 @@ class MSELossMasked(nn.Module):
         return loss
 
 
+def sample_wise_min_max(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Min-Max normalize tensor through first dimension
+    Shapes:
+        - x: :math:`[B, D1, D2]`
+        - m: :math:`[B, D1, 1]`
+    """
+    maximum = torch.amax(x.masked_fill(~mask, 0), dim=(1, 2), keepdim=True)
+    minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True)
+    return (x - minimum) / (maximum - minimum + 1e-8)
+
+
 class SSIMLoss(torch.nn.Module):
-    """SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity"""
+    """SSIM loss as (1 - SSIM)
+    SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity
+    """
 
     def __init__(self):
         super().__init__()
-        self.loss_func = ssim
+        self.loss_func = _SSIMLoss()
 
-    def forward(self, y_hat, y, length=None):
+    def forward(self, y_hat, y, length):
         """
         Args:
             y_hat (tensor): model prediction values.
             y (tensor): target values.
-            length (tensor): length of each sample in a batch.
+            length (tensor): length of each sample in a batch for masking.
 
         Shapes:
             y_hat: B x T x D
             y: B x T x D
             length: B
 
         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
-        if length is not None:
-            m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device)
-            y_hat, y = y_hat * m, y * m
-        return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1))
+        mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2)
+        y_norm = sample_wise_min_max(y, mask)
+        y_hat_norm = sample_wise_min_max(y_hat, mask)
+        ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1))
+
+        if ssim_loss.item() > 1.0:
+            print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
+            # was `ssim_loss == 1.0`, a no-op comparison; assign to actually clamp
+            ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
+
+        if ssim_loss.item() < 0.0:
+            print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
+            # was `ssim_loss == 0.0`, a no-op comparison; assign to actually clamp
+            ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
+
+        return ssim_loss
 
 
 class AttentionEntropyLoss(nn.Module):
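For reference, a minimal usage sketch of the reworked loss (not part of the commit; it assumes the hunk above lives in TTS/tts/layers/losses.py, that `np` is numpy imported in that module, and that the new TTS.tts.utils.ssim shown below is in place):

    # Hypothetical example, for illustration only.
    import torch
    from TTS.tts.layers.losses import SSIMLoss  # assumed module path

    criterion = SSIMLoss()
    y_hat = torch.rand(4, 57, 80)            # B x T x D, e.g. predicted mel frames
    y = torch.rand(4, 57, 80)                # B x T x D target
    length = torch.tensor([57, 50, 35, 20])  # valid frames per sample

    # Each sample is min-max normalized to [0, 1] under its mask via
    # sample_wise_min_max() before the piq-style _SSIMLoss is applied,
    # which keeps the inputs inside the validated data_range.
    loss = criterion(y_hat, y, length)
    print(loss.item())  # scalar in [0, 1]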
@@ -1,73 +1,347 @@
-# taken from https://github.com/Po-Hsun-Su/pytorch-ssim
+# Adopted from https://github.com/photosynthesis-team/piq
 
-from math import exp
+from typing import Optional, Tuple, Union, List
 
 import torch
 import torch.nn.functional as F
-from torch.autograd import Variable
+from torch.nn.modules.loss import _Loss
 
 
-def gaussian(window_size, sigma):
-    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
-    return gauss / gauss.sum()
-
-
-def create_window(window_size, channel):
-    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
-    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
-    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
-    return window
-
-
-def _ssim(img1, img2, window, window_size, channel, size_average=True):
-    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
-    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
-
-    # TODO: check if you need AMP disabled
-    # with torch.cuda.amp.autocast(enabled=False):
-    mu1_sq = mu1.float().pow(2)
-    mu2_sq = mu2.float().pow(2)
-    mu1_mu2 = mu1 * mu2
-
-    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
-    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
-    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
-
-    C1 = 0.01**2
-    C2 = 0.03**2
-
-    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
-
-    if size_average:
-        return ssim_map.mean()
-    return ssim_map.mean(1).mean(1).mean(1)
-
-
-class SSIM(torch.nn.Module):
-    def __init__(self, window_size=11, size_average=True):
-        super().__init__()
-        self.window_size = window_size
-        self.size_average = size_average
-        self.channel = 1
-        self.window = create_window(window_size, self.channel)
-
-    def forward(self, img1, img2):
-        (_, channel, _, _) = img1.size()
-
-        if channel == self.channel and self.window.data.type() == img1.data.type():
-            window = self.window
-        else:
-            window = create_window(self.window_size, channel)
-            window = window.type_as(img1)
-
-            self.window = window
-            self.channel = channel
-
-        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
-
-
-def ssim(img1, img2, window_size=11, size_average=True):
-    (_, channel, _, _) = img1.size()
-    window = create_window(window_size, channel)
-    window = window.type_as(img1)
-    return _ssim(img1, img2, window, window_size, channel, size_average)
+def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
+    r"""Reduce input in batch dimension if needed.
+    Args:
+        x: Tensor with shape (N, *).
+        reduction: Specifies the reduction type:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
+    """
+    if reduction == 'none':
+        return x
+    elif reduction == 'mean':
+        return x.mean(dim=0)
+    elif reduction == 'sum':
+        return x.sum(dim=0)
+    else:
+        raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
+
+
+def _validate_input(
+    tensors: List[torch.Tensor],
+    dim_range: Tuple[int, int] = (0, -1),
+    data_range: Tuple[float, float] = (0., -1.),
+    # size_dim_range: Tuple[float, float] = (0., -1.),
+    size_range: Optional[Tuple[int, int]] = None,
+) -> None:
+    r"""Check that input(-s) satisfies the requirements
+    Args:
+        tensors: Tensors to check
+        dim_range: Allowed number of dimensions. (min, max)
+        data_range: Allowed range of values in tensors. (min, max)
+        size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
+    """
+    if not __debug__:
+        return
+
+    x = tensors[0]
+
+    for t in tensors:
+        assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
+        assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
+
+        if size_range is None:
+            assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
+        else:
+            assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
+                f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
+
+        if dim_range[0] == dim_range[1]:
+            assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
+        elif dim_range[0] < dim_range[1]:
+            assert dim_range[0] <= t.dim() <= dim_range[1], \
+                f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
+
+        if data_range[0] < data_range[1]:
+            assert data_range[0] <= t.min(), \
+                f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
+            assert t.max() <= data_range[1], \
+                f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
+
+
+def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor:
+    r"""Returns 2D Gaussian kernel N(0,`sigma`^2)
+    Args:
+        kernel_size: Size of the kernel
+        sigma: Std of the distribution
+    Returns:
+        gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
+    """
+    coords = torch.arange(kernel_size, dtype=torch.float32)
+    coords -= (kernel_size - 1) / 2.
+
+    g = coords ** 2
+    g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp()
+
+    g /= g.sum()
+    return g.unsqueeze(0)
+
+
+def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
+         data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False,
+         downsample: bool = True, k1: float = 0.01, k2: float = 0.03) -> List[torch.Tensor]:
+    r"""Interface of Structural Similarity (SSIM) index.
+    Inputs are supposed to be in range ``[0, data_range]``.
+    To match performance with skimage and tensorflow set ``'downsample' = True``.
+
+    Args:
+        x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+        y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
+        kernel_sigma: Sigma of normal distribution.
+        data_range: Maximum value range of images (usually 1.0 or 255).
+        reduction: Specifies the reduction type:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
+        full: Return cs map or not.
+        downsample: Perform average pool before SSIM computation. Default: True
+        k1: Algorithm parameter, K1 (small constant).
+        k2: Algorithm parameter, K2 (small constant).
+            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
+
+    Returns:
+        Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned
+        as a tensor of size 2.
+
+    References:
+        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
+        Image quality assessment: From error visibility to structural similarity.
+        IEEE Transactions on Image Processing, 13, 600-612.
+        https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
+        DOI: `10.1109/TIP.2003.819861`
+    """
+    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
+    _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))
+
+    x = x / float(data_range)
+    y = y / float(data_range)
+
+    # Average pool image if the size is large enough
+    f = max(1, round(min(x.size()[-2:]) / 256))
+    if (f > 1) and downsample:
+        x = F.avg_pool2d(x, kernel_size=f)
+        y = F.avg_pool2d(y, kernel_size=f)
+
+    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
+    _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
+    ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2)
+    ssim_val = ssim_map.mean(1)
+    cs = cs_map.mean(1)
+
+    ssim_val = _reduce(ssim_val, reduction)
+    cs = _reduce(cs, reduction)
+
+    if full:
+        return [ssim_val, cs]
+
+    return ssim_val
+
+
+class SSIMLoss(_Loss):
+    r"""Creates a criterion that measures the structural similarity index error between
+    each element in the input :math:`x` and target :math:`y`.
+
+    To match performance with skimage and tensorflow set ``'downsample' = True``.
+
+    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
+
+    .. math::
+        SSIM = \{ssim_1,\dots,ssim_{N \times C}\}\\
+        ssim_{l}(x, y) = \frac{(2 \mu_x \mu_y + c_1) (2 \sigma_{xy} + c_2)}
+        {(\mu_x^2 +\mu_y^2 + c_1)(\sigma_x^2 +\sigma_y^2 + c_2)},
+
+    where :math:`N` is the batch size, `C` is the channel size. If :attr:`reduction` is not ``'none'``
+    (default ``'mean'``), then:
+
+    .. math::
+        SSIMLoss(x, y) =
+        \begin{cases}
+            \operatorname{mean}(1 - SSIM), & \text{if reduction} = \text{'mean';}\\
+            \operatorname{sum}(1 - SSIM),  & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
+    of :math:`n` elements each.
+
+    The sum operation still operates over all the elements, and divides by :math:`n`.
+    The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
+    In case of 5D input tensors, complex value is returned as a tensor of size 2.
+
+    Args:
+        kernel_size: By default, the mean and covariance of a pixel is obtained
+            by convolution with given filter_size.
+        kernel_sigma: Standard deviation for Gaussian kernel.
+        k1: Coefficient related to c1 in the above equation.
+        k2: Coefficient related to c2 in the above equation.
+        downsample: Perform average pool before SSIM computation. Default: True
+        reduction: Specifies the reduction type:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
+        data_range: Maximum value range of images (usually 1.0 or 255).
+
+    Examples:
+        >>> loss = SSIMLoss()
+        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
+        >>> y = torch.rand(3, 3, 256, 256)
+        >>> output = loss(x, y)
+        >>> output.backward()
+
+    References:
+        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
+        Image quality assessment: From error visibility to structural similarity.
+        IEEE Transactions on Image Processing, 13, 600-612.
+        https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
+        DOI: `10.1109/TIP.2003.819861`
+    """
+    __constants__ = ['kernel_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']
+
+    def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03,
+                 downsample: bool = True, reduction: str = 'mean', data_range: Union[int, float] = 1.) -> None:
+        super().__init__()
+
+        # Generic loss parameters.
+        self.reduction = reduction
+
+        # Loss-specific parameters.
+        self.kernel_size = kernel_size
+
+        # This check might look redundant because kernel size is checked within the ssim function anyway.
+        # However, it allows the loss to fail fast when it is initialised, before training has started.
+        assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
+        self.kernel_sigma = kernel_sigma
+        self.k1 = k1
+        self.k2 = k2
+        self.downsample = downsample
+        self.data_range = data_range
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        r"""Computation of Structural Similarity (SSIM) index as a loss function.
+
+        Args:
+            x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+            y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+
+        Returns:
+            Value of SSIM loss to be minimized, i.e. ``1 - ssim`` in [0, 1] range. In case of 5D input tensors,
+            complex value is returned as a tensor of size 2.
+        """
+        score = ssim(x=x, y=y, kernel_size=self.kernel_size, kernel_sigma=self.kernel_sigma, downsample=self.downsample,
+                     data_range=self.data_range, reduction=self.reduction, full=False, k1=self.k1, k2=self.k2)
+        return torch.ones_like(score) - score
+
+
+def _ssim_per_channel(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor,
+                      data_range: Union[float, int] = 1., k1: float = 0.01,
+                      k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    r"""Calculate Structural Similarity (SSIM) index for X and Y per channel.
+
+    Args:
+        x: An input tensor. Shape :math:`(N, C, H, W)`.
+        y: A target tensor. Shape :math:`(N, C, H, W)`.
+        kernel: 2D Gaussian kernel.
+        data_range: Maximum value range of images (usually 1.0 or 255).
+        k1: Algorithm parameter, K1 (small constant, see [1]).
+        k2: Algorithm parameter, K2 (small constant, see [1]).
+            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
+
+    Returns:
+        Full Value of Structural Similarity (SSIM) index.
+    """
+    if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2):
+        raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
+                         f'Kernel size: {kernel.size()}')
+
+    c1 = k1 ** 2
+    c2 = k2 ** 2
+    n_channels = x.size(1)
+    mu_x = F.conv2d(x, weight=kernel, stride=1, padding=0, groups=n_channels)
+    mu_y = F.conv2d(y, weight=kernel, stride=1, padding=0, groups=n_channels)
+
+    mu_xx = mu_x ** 2
+    mu_yy = mu_y ** 2
+    mu_xy = mu_x * mu_y
+
+    sigma_xx = F.conv2d(x ** 2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx
+    sigma_yy = F.conv2d(y ** 2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy
+    sigma_xy = F.conv2d(x * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xy
+
+    # Contrast sensitivity (CS) with alpha = beta = gamma = 1.
+    cs = (2. * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
+
+    # Structural similarity (SSIM)
+    ss = (2. * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs
+
+    ssim_val = ss.mean(dim=(-1, -2))
+    cs = cs.mean(dim=(-1, -2))
+    return ssim_val, cs
+
+
+def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor,
+                              data_range: Union[float, int] = 1., k1: float = 0.01,
+                              k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel.
+
+    Args:
+        x: An input tensor. Shape :math:`(N, C, H, W, 2)`.
+        y: A target tensor. Shape :math:`(N, C, H, W, 2)`.
+        kernel: 2-D gauss kernel.
+        data_range: Maximum value range of images (usually 1.0 or 255).
+        k1: Algorithm parameter, K1 (small constant, see [1]).
+        k2: Algorithm parameter, K2 (small constant, see [1]).
+            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
+
+    Returns:
+        Full Value of Complex Structural Similarity (SSIM) index.
+    """
+    n_channels = x.size(1)
+    if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2):
+        raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
+                         f'Kernel size: {kernel.size()}')
+
+    c1 = k1 ** 2
+    c2 = k2 ** 2
+
+    x_real = x[..., 0]
+    x_imag = x[..., 1]
+    y_real = y[..., 0]
+    y_imag = y[..., 1]
+
+    mu1_real = F.conv2d(x_real, weight=kernel, stride=1, padding=0, groups=n_channels)
+    mu1_imag = F.conv2d(x_imag, weight=kernel, stride=1, padding=0, groups=n_channels)
+    mu2_real = F.conv2d(y_real, weight=kernel, stride=1, padding=0, groups=n_channels)
+    mu2_imag = F.conv2d(y_imag, weight=kernel, stride=1, padding=0, groups=n_channels)
+
+    mu1_sq = mu1_real.pow(2) + mu1_imag.pow(2)
+    mu2_sq = mu2_real.pow(2) + mu2_imag.pow(2)
+    mu1_mu2_real = mu1_real * mu2_real - mu1_imag * mu2_imag
+    mu1_mu2_imag = mu1_real * mu2_imag + mu1_imag * mu2_real
+
+    compensation = 1.0
+
+    x_sq = x_real.pow(2) + x_imag.pow(2)
+    y_sq = y_real.pow(2) + y_imag.pow(2)
+    x_y_real = x_real * y_real - x_imag * y_imag
+    x_y_imag = x_real * y_imag + x_imag * y_real
+
+    sigma1_sq = F.conv2d(x_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_sq
+    sigma2_sq = F.conv2d(y_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu2_sq
+    sigma12_real = F.conv2d(x_y_real, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_real
+    sigma12_imag = F.conv2d(x_y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_imag
+    sigma12 = torch.stack((sigma12_imag, sigma12_real), dim=-1)
+    mu1_mu2 = torch.stack((mu1_mu2_real, mu1_mu2_imag), dim=-1)
+    # Set alpha = beta = gamma = 1.
+    cs_map = (sigma12 * 2 + c2 * compensation) / (sigma1_sq.unsqueeze(-1) + sigma2_sq.unsqueeze(-1) + c2 * compensation)
+    ssim_map = (mu1_mu2 * 2 + c1 * compensation) / (mu1_sq.unsqueeze(-1) + mu2_sq.unsqueeze(-1) + c1 * compensation)
+    ssim_map = ssim_map * cs_map
+
+    ssim_val = ssim_map.mean(dim=(-2, -3))
+    cs = cs_map.mean(dim=(-2, -3))
+
+    return ssim_val, cs
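For reference, a minimal usage sketch of the new piq-style interface (not part of the commit; it assumes this file is importable as TTS.tts.utils.ssim). The 4D call mirrors the Examples block in the SSIMLoss docstring; the 5D call exercises the complex path:

    # Hypothetical example, for illustration only.
    import torch
    from TTS.tts.utils.ssim import SSIMLoss, ssim  # assumed module path

    x = torch.rand(3, 1, 128, 128)  # N x C x H x W, values in [0, data_range]
    y = torch.rand(3, 1, 128, 128)

    score = ssim(x, y, data_range=1.)     # SSIM index, higher is better
    loss = SSIMLoss(data_range=1.)(x, y)  # 1 - SSIM, lower is better
    assert torch.isclose(loss, 1 - score)

    # 5D (complex) inputs: the last dimension holds (real, imag);
    # the result comes back as a tensor of size 2.
    xc = torch.rand(2, 1, 64, 64, 2)
    yc = torch.rand(2, 1, 64, 64, 2)
    print(ssim(xc, yc).shape)  # torch.Size([2])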