mirror of https://github.com/coqui-ai/TTS.git
linter fixes
This commit is contained in:
parent 9782d9ea5d
commit f890454de3
@@ -20,7 +20,7 @@ class Synthesizer(object):
     """General 🐸 TTS interface for inference. It takes a tts and a vocoder
     model and synthesize speech from the provided text.

     The text is divided into a list of sentences using `pysbd` and synthesize
     speech on each sentence separately.

     If you have certain special characters in your text, you need to handle
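The docstring above references `pysbd` sentence splitting; a minimal sketch of that step in isolation (the `language` and `clean` arguments are illustrative assumptions, not values taken from this diff):

import pysbd

# Split input text into sentences before synthesizing each one separately,
# as the Synthesizer docstring describes.
segmenter = pysbd.Segmenter(language="en", clean=True)
sentences = segmenter.segment("Hello there. How are you?")
# expected: ['Hello there.', 'How are you?']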
@@ -77,8 +77,8 @@ class GANDataset(Dataset):
         """Pad samples shorter than the output sequence length"""
         if len(audio) < self.seq_len:
             audio = np.pad(audio, (0, self.seq_len - len(audio)),
                            mode='constant',
                            constant_values=0.0)

         if mel is not None and mel.shape[1] < self.feat_frame_len:
             pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0]
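The padding branch above right-pads clips that are shorter than the training sequence length; a standalone sketch of the same `np.pad` call with assumed sizes:

import numpy as np

seq_len = 8                        # illustrative target length
audio = np.array([0.1, 0.2, 0.3])
if len(audio) < seq_len:
    # right-pad with zeros up to seq_len, mirroring GANDataset
    audio = np.pad(audio, (0, seq_len - len(audio)),
                   mode='constant', constant_values=0.0)
print(audio)  # [0.1 0.2 0.3 0.  0.  0.  0.  0. ]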
@@ -1,36 +1,67 @@
 import torch
+import librosa
 from torch import nn
 from torch.nn import functional as F


-class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
-    # TODO: move this to audio.py with a transparent torch API.
-    def __init__(self, n_fft, hop_length, win_length, pad_mode='reflect', window='hann_window'):
+class TorchSTFT(nn.Module):
+    """TODO: Merge this with audio.py"""
+    def __init__(self,
+                 n_fft,
+                 hop_length,
+                 win_length,
+                 window='hann_window',
+                 sample_rate=None,
+                 mel_fmin=0,
+                 mel_fmax=None,
+                 n_mels=80,
+                 use_mel=False):
         """ Torch based STFT operation """
         super(TorchSTFT, self).__init__()
         self.n_fft = n_fft
         self.hop_length = hop_length
         self.win_length = win_length
-        self.pad_mode = pad_mode
+        self.sample_rate = sample_rate
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.n_mels = n_mels
+        self.use_mel = use_mel
         self.window = nn.Parameter(getattr(torch, window)(win_length),
                                    requires_grad=False)
+        self.mel_basis = None
+        if use_mel:
+            self._build_mel_basis()

     def __call__(self, x):
+        padding = int((self.n_fft - self.hop_length) / 2)
+        x = torch.nn.functional.pad(x, (padding, padding), mode='reflect')
         # B x D x T x 2
-        o = torch.stft(x,
-                       self.n_fft,
-                       self.hop_length,
-                       self.win_length,
-                       self.window,
-                       center=True,
-                       pad_mode=self.pad_mode,  # needs to be compatible with audio.py
-                       normalized=False,
-                       onesided=True,
-                       return_complex=False)
+        o = torch.stft(
+            x.squeeze(1),
+            self.n_fft,
+            self.hop_length,
+            self.win_length,
+            self.window,
+            center=True,
+            pad_mode="reflect",  # compatible with audio.py
+            normalized=False,
+            onesided=True,
+            return_complex=False)
         M = o[:, :, :, 0]
         P = o[:, :, :, 1]
-        return torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
+        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
+        if self.use_mel:
+            S = torch.matmul(self.mel_basis.to(x), S)
+        return S
+
+    def _build_mel_basis(self):
+        mel_basis = librosa.filters.mel(self.sample_rate,
+                                        self.n_fft,
+                                        n_mels=self.n_mels,
+                                        fmin=self.mel_fmin,
+                                        fmax=self.mel_fmax)
+        self.mel_basis = torch.from_numpy(mel_basis).float()


 #################################
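A quick shape check of the rewritten `TorchSTFT`; batch size, sample rate, and input length are assumptions for illustration:

import torch

stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024)
mel_stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024,
                     sample_rate=22050, n_mels=80, use_mel=True)

x = torch.randn(4, 1, 22050)   # B x 1 x T waveform batch (illustrative)
print(stft(x).shape)           # (4, 513, frames): linear magnitude bins
print(mel_stft(x).shape)       # (4, 80, frames): projected onto the mel basis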
@@ -39,7 +70,9 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method


 class STFTLoss(nn.Module):
-    """ Single scale STFT Loss """
+    """ STFT loss. Generated and real waveforms are converted to
+    spectrograms and compared with L1 and spectral convergence losses.
+    From the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
     def __init__(self, n_fft, hop_length, win_length):
         super(STFTLoss, self).__init__()
         self.n_fft = n_fft
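The body of `STFTLoss.forward` is outside this hunk; a minimal sketch of the two terms the new docstring names, following the ParallelWaveGAN formulation rather than this commit's exact code:

import torch
from torch.nn import functional as F

def stft_loss_terms(y_hat_M, y_M):
    # L1 distance between log magnitude spectrograms
    loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
    # spectral convergence: Frobenius distance normalized by the target norm
    loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
    return loss_mag, loss_sc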
@@ -57,7 +90,9 @@ class STFTLoss(nn.Module):
         return loss_mag, loss_sc


 class MultiScaleSTFTLoss(torch.nn.Module):
-    """ Multi scale STFT loss """
+    """ Multi-scale STFT loss. Generated and real waveforms are converted to
+    spectrograms and compared with L1 and spectral convergence losses.
+    From the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
     def __init__(self,
                  n_ffts=(1024, 2048, 512),
                  hop_lengths=(120, 240, 50),
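Only the resolution tuples are visible here; the multi-scale combination is sketched below, consistent with the `loss_mag /= N` averaging in the next hunk (the `win_lengths` values are assumptions):

# one STFTLoss per (n_fft, hop_length, win_length) resolution
losses = [STFTLoss(n, h, w)
          for n, h, w in zip((1024, 2048, 512),
                             (120, 240, 50),
                             (600, 1200, 240))]  # win_lengths assumed

def multi_scale_stft_loss(y_hat, y):
    N = len(losses)
    loss_mag, loss_sc = 0, 0
    for f in losses:
        lm, lsc = f(y_hat, y)   # each scale returns (loss_mag, loss_sc)
        loss_mag += lm
        loss_sc += lsc
    return loss_mag / N, loss_sc / N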
@@ -79,9 +114,30 @@ class MultiScaleSTFTLoss(torch.nn.Module):
         loss_mag /= N
         return loss_mag, loss_sc


+class L1SpecLoss(nn.Module):
+    """ L1 loss over spectrograms as described in the HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf"""
+    def __init__(self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True):
+        super().__init__()
+        self.use_mel = use_mel
+        self.stft = TorchSTFT(n_fft,
+                              hop_length,
+                              win_length,
+                              sample_rate=sample_rate,
+                              mel_fmin=mel_fmin,
+                              mel_fmax=mel_fmax,
+                              n_mels=n_mels,
+                              use_mel=use_mel)
+
+    def forward(self, y_hat, y):
+        y_hat_M = self.stft(y_hat)
+        y_M = self.stft(y)
+        # magnitude loss
+        loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
+        return loss_mag
+
+
 class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
-    """ Multiscale STFT loss for multi band model outputs """
+    """ Multiscale STFT loss for multi-band model outputs.
+    From the Multi-Band MelGAN paper https://arxiv.org/abs/2005.05106"""
     # pylint: disable=no-self-use
     def forward(self, y_hat, y):
         y_hat = y_hat.view(-1, 1, y_hat.shape[2])
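A smoke test of the new `L1SpecLoss`; all parameter values and tensor shapes are placeholders, not the project's configs:

import torch

spec_loss = L1SpecLoss(sample_rate=22050, n_fft=1024, hop_length=256,
                       win_length=1024, mel_fmin=0, mel_fmax=8000, n_mels=80)
y_hat = torch.rand(4, 1, 16384)  # generated waveform batch (assumed shape)
y = torch.rand(4, 1, 16384)      # ground-truth waveform batch
loss = spec_loss(y_hat, y)       # scalar L1 over log-mel magnitudes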
@@ -143,9 +199,12 @@ class MelganFeatureLoss(nn.Module):
     # pylint: disable=no-self-use
     def forward(self, fake_feats, real_feats):
         loss_feats = 0
-        for fake_feat, real_feat in zip(fake_feats, real_feats):
-            loss_feats += self.loss_func(fake_feat, real_feat)
-        loss_feats /= len(fake_feats) + len(real_feats)
+        num_feats = 0
+        for idx, _ in enumerate(fake_feats):
+            for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]):
+                loss_feats += self.loss_func(fake_feat, real_feat)
+                num_feats += 1
+        loss_feats = loss_feats / num_feats
         return loss_feats
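The rewritten `MelganFeatureLoss.forward` expects one feature list per discriminator, i.e. a list of lists of tensors; a toy illustration with assumed shapes:

import torch

# two discriminators, each exposing three intermediate feature maps
fake_feats = [[torch.randn(1, 16, 100) for _ in range(3)] for _ in range(2)]
real_feats = [[torch.randn(1, 16, 100) for _ in range(3)] for _ in range(2)]
# the nested loop now averages over all 2 * 3 = 6 feature maps, instead of
# dividing by len(fake_feats) + len(real_feats) as the old code did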
@@ -198,24 +257,31 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):


 class GeneratorLoss(nn.Module):
+    """Generator loss wrapper. Based on the model configuration it sets the right set of loss
+    functions and computes the losses. It allows experimenting with different combinations of
+    loss functions and models by just changing configurations.
+
+    Args:
+        C (AttrDict): model configuration.
+    """
     def __init__(self, C):
-        """ Compute Generator Loss values depending on training
-        configuration """
-        super(GeneratorLoss, self).__init__()
+        super().__init__()
         assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
             " [!] Cannot use HingeGANLoss and MSEGANLoss together."

-        self.use_stft_loss = C.use_stft_loss
-        self.use_subband_stft_loss = C.use_subband_stft_loss
-        self.use_mse_gan_loss = C.use_mse_gan_loss
-        self.use_hinge_gan_loss = C.use_hinge_gan_loss
-        self.use_feat_match_loss = C.use_feat_match_loss
+        self.use_stft_loss = C.use_stft_loss if 'use_stft_loss' in C else False
+        self.use_subband_stft_loss = C.use_subband_stft_loss if 'use_subband_stft_loss' in C else False
+        self.use_mse_gan_loss = C.use_mse_gan_loss if 'use_mse_gan_loss' in C else False
+        self.use_hinge_gan_loss = C.use_hinge_gan_loss if 'use_hinge_gan_loss' in C else False
+        self.use_feat_match_loss = C.use_feat_match_loss if 'use_feat_match_loss' in C else False
+        self.use_l1_spec_loss = C.use_l1_spec_loss if 'use_l1_spec_loss' in C else False

-        self.stft_loss_weight = C.stft_loss_weight
-        self.subband_stft_loss_weight = C.subband_stft_loss_weight
-        self.mse_gan_loss_weight = C.mse_G_loss_weight
-        self.hinge_gan_loss_weight = C.hinge_G_loss_weight
-        self.feat_match_loss_weight = C.feat_match_loss_weight
+        self.stft_loss_weight = C.stft_loss_weight if 'stft_loss_weight' in C else 0.0
+        self.subband_stft_loss_weight = C.subband_stft_loss_weight if 'subband_stft_loss_weight' in C else 0.0
+        self.mse_gan_loss_weight = C.mse_G_loss_weight if 'mse_G_loss_weight' in C else 0.0
+        self.hinge_gan_loss_weight = C.hinge_G_loss_weight if 'hinge_G_loss_weight' in C else 0.0
+        self.feat_match_loss_weight = C.feat_match_loss_weight if 'feat_match_loss_weight' in C else 0.0
+        self.l1_spec_loss_weight = C.l1_spec_loss_weight if 'l1_spec_loss_weight' in C else 0.0

         if C.use_stft_loss:
             self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
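With the new `'key' in C` guards, missing flags fall back to False and missing weights to 0.0; a hypothetical config sketch (the docstring's AttrDict is stood in by a minimal dict subclass, and all values are illustrative):

class AttrDict(dict):
    # minimal stand-in: a dict with attribute access, like the real config object
    def __getattr__(self, key):
        return self[key]

C = AttrDict(
    use_stft_loss=True,
    use_subband_stft_loss=False,
    use_mse_gan_loss=True,
    use_hinge_gan_loss=False,
    use_feat_match_loss=True,
    use_l1_spec_loss=False,
    stft_loss_weight=0.5,            # weights are illustrative values
    mse_G_loss_weight=2.5,
    feat_match_loss_weight=10.0,
    stft_loss_params={'n_ffts': (1024, 2048, 512),
                      'hop_lengths': (120, 240, 50),
                      'win_lengths': (600, 1200, 240)},
)

Note the assert and the later `if C.use_...:` branches still read some flags directly, so configs typically define every flag even when it is False.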
@@ -227,6 +293,9 @@ class GeneratorLoss(nn.Module):
             self.hinge_loss = HingeGLoss()
         if C.use_feat_match_loss:
             self.feat_match_loss = MelganFeatureLoss()
+        if C.use_l1_spec_loss:
+            assert C.audio['sample_rate'] == C.l1_spec_loss_params['sample_rate']
+            self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)

     def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None):
         gen_loss = 0
@@ -235,35 +304,41 @@ class GeneratorLoss(nn.Module):

         # STFT Loss
         if self.use_stft_loss:
-            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat.squeeze(1), y.squeeze(1))
+            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
             return_dict['G_stft_loss_mg'] = stft_loss_mg
             return_dict['G_stft_loss_sc'] = stft_loss_sc
-            gen_loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
+            gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
+
+        # L1 Spec loss
+        if self.use_l1_spec_loss:
+            l1_spec_loss = self.l1_spec_loss(y_hat, y)
+            return_dict['G_l1_spec_loss'] = l1_spec_loss
+            gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss

         # subband STFT Loss
         if self.use_subband_stft_loss:
             subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
             return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg
             return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc
-            gen_loss += self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
+            gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)

         # multiscale MSE adversarial loss
         if self.use_mse_gan_loss and scores_fake is not None:
             mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
             return_dict['G_mse_fake_loss'] = mse_fake_loss
-            adv_loss += self.mse_gan_loss_weight * mse_fake_loss
+            adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss

         # multiscale Hinge adversarial loss
         if self.use_hinge_gan_loss and scores_fake is not None:
             hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
             return_dict['G_hinge_fake_loss'] = hinge_fake_loss
-            adv_loss += self.hinge_gan_loss_weight * hinge_fake_loss
+            adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss

         # Feature Matching Loss
-        if self.use_feat_match_loss and not feats_fake:
+        if self.use_feat_match_loss and feats_fake is not None:
             feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
             return_dict['G_feat_match_loss'] = feat_match_loss
-            adv_loss += self.feat_match_loss_weight * feat_match_loss
+            adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
         return_dict['G_loss'] = gen_loss + adv_loss
         return_dict['G_gen_loss'] = gen_loss
         return_dict['G_adv_loss'] = adv_loss
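End to end, a generator step with only the STFT loss enabled can exercise the wrapper without a discriminator; a sketch with stand-in tensors (exact tensor ranks depend on code outside this diff):

import torch

loss_fn = GeneratorLoss(C)       # C as sketched after the __init__ hunk
y = torch.randn(4, 1, 16384)     # real waveforms (illustrative)
y_hat = torch.randn(4, 1, 16384, requires_grad=True)  # generator output stand-in
out = loss_fn(y_hat=y_hat, y=y)  # adversarial inputs left as None are skipped
out['G_loss'].backward()
print(out['G_stft_loss_mg'].item(), out['G_stft_loss_sc'].item())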
@@ -271,10 +346,9 @@ class GeneratorLoss(nn.Module):


 class DiscriminatorLoss(nn.Module):
-    """ Compute Discriminator Loss values depending on training
-    configuration """
+    """Like ```GeneratorLoss```"""
     def __init__(self, C):
-        super(DiscriminatorLoss, self).__init__()
+        super().__init__()
         assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
             " [!] Cannot use HingeGANLoss and MSEGANLoss together."
