linter fixes

This commit is contained in:
Eren Gölge 2021-04-07 12:36:03 +02:00
parent 9782d9ea5d
commit f890454de3
3 changed files with 122 additions and 48 deletions

View File

@ -20,7 +20,7 @@ class Synthesizer(object):
"""General 🐸 TTS interface for inference. It takes a tts and a vocoder """General 🐸 TTS interface for inference. It takes a tts and a vocoder
model and synthesize speech from the provided text. model and synthesize speech from the provided text.
The text is divided into a list of sentences using `pysbd` and synthesize The text is divided into a list of sentences using `pysbd` and synthesize
speech on each sentence separately. speech on each sentence separately.
If you have certain special characters in your text, you need to handle If you have certain special characters in your text, you need to handle

View File

@ -77,8 +77,8 @@ class GANDataset(Dataset):
"""Pad samples shorter than the output sequence length""" """Pad samples shorter than the output sequence length"""
if len(audio) < self.seq_len: if len(audio) < self.seq_len:
audio = np.pad(audio, (0, self.seq_len - len(audio)), audio = np.pad(audio, (0, self.seq_len - len(audio)),
mode='constant', mode='constant',
constant_values=0.0) constant_values=0.0)
if mel is not None and mel.shape[1] < self.feat_frame_len: if mel is not None and mel.shape[1] < self.feat_frame_len:
pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0] pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0]

View File

@ -1,36 +1,67 @@
import torch import torch
import librosa
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
class TorchSTFT(nn.Module): # pylint: disable=abstract-method class TorchSTFT(nn.Module):
# TODO: move this to audio.py with a transparent torch API. """TODO: Merge this with audio.py"""
def __init__(self, n_fft, hop_length, win_length, pad_mode='reflect', window='hann_window'): def __init__(self,
n_fft,
hop_length,
win_length,
window='hann_window',
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False):
""" Torch based STFT operation """ """ Torch based STFT operation """
super(TorchSTFT, self).__init__() super(TorchSTFT, self).__init__()
self.n_fft = n_fft self.n_fft = n_fft
self.hop_length = hop_length self.hop_length = hop_length
self.win_length = win_length self.win_length = win_length
self.pad_mode = pad_mode self.sample_rate = sample_rate
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.window = nn.Parameter(getattr(torch, window)(win_length), self.window = nn.Parameter(getattr(torch, window)(win_length),
requires_grad=False) requires_grad=False)
self.mel_basis = None
if use_mel:
self._build_mel_basis()
def __call__(self, x): def __call__(self, x):
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode='reflect')
# B x D x T x 2 # B x D x T x 2
o = torch.stft(x, o = torch.stft(
self.n_fft, x.squeeze(1),
self.hop_length, self.n_fft,
self.win_length, self.hop_length,
self.window, self.win_length,
center=True, self.window,
pad_mode=self.pad_mode, # needs to be compatible with audio.py center=True,
normalized=False, pad_mode="reflect", # compatible with audio.py
onesided=True, normalized=False,
return_complex=False) onesided=True,
return_complex=False)
M = o[:, :, :, 0] M = o[:, :, :, 0]
P = o[:, :, :, 1] P = o[:, :, :, 1]
return torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
return S
def _build_mel_basis(self):
    """Create the mel filterbank and store it on ``self.mel_basis``.

    The filterbank is computed by ``librosa.filters.mel`` from
    ``self.sample_rate``, ``self.n_fft``, ``self.n_mels``, ``self.mel_fmin``
    and ``self.mel_fmax`` and converted to a float32 torch tensor.

    Raises:
        ValueError: if ``self.sample_rate`` is None, since the filterbank
            cannot be built without a sampling rate.
    """
    if self.sample_rate is None:
        raise ValueError(" [!] `sample_rate` must be set to build the mel basis.")
    # `sr` and `n_fft` are passed by keyword: recent librosa releases accept
    # them as keyword-only arguments, and older releases accept keywords too.
    mel_basis = librosa.filters.mel(sr=self.sample_rate,
                                    n_fft=self.n_fft,
                                    n_mels=self.n_mels,
                                    fmin=self.mel_fmin,
                                    fmax=self.mel_fmax)
    self.mel_basis = torch.from_numpy(mel_basis).float()
################################# #################################
@ -39,7 +70,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
class STFTLoss(nn.Module): class STFTLoss(nn.Module):
""" Single scale STFT Loss """ """ STFT loss. Input generate and real waveforms are converted
to spectrograms and compared using L1 and spectral convergence losses.
It is from the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
def __init__(self, n_fft, hop_length, win_length): def __init__(self, n_fft, hop_length, win_length):
super(STFTLoss, self).__init__() super(STFTLoss, self).__init__()
self.n_fft = n_fft self.n_fft = n_fft
@ -57,7 +90,9 @@ class STFTLoss(nn.Module):
return loss_mag, loss_sc return loss_mag, loss_sc
class MultiScaleSTFTLoss(torch.nn.Module): class MultiScaleSTFTLoss(torch.nn.Module):
""" Multi scale STFT loss """ """ Multi-scale STFT loss. Input generate and real waveforms are converted
to spectrograms and compared using L1 and spectral convergence losses.
It is from the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
def __init__(self, def __init__(self,
n_ffts=(1024, 2048, 512), n_ffts=(1024, 2048, 512),
hop_lengths=(120, 240, 50), hop_lengths=(120, 240, 50),
@ -79,9 +114,30 @@ class MultiScaleSTFTLoss(torch.nn.Module):
loss_mag /= N loss_mag /= N
return loss_mag, loss_sc return loss_mag, loss_sc
class L1SpecLoss(nn.Module):
    """L1 distance between the log spectrograms of two waveforms.

    Spectrogram loss as described in the HiFiGAN paper
    https://arxiv.org/pdf/2010.05646.pdf

    Args:
        sample_rate, n_fft, hop_length, win_length, mel_fmin, mel_fmax,
        n_mels, use_mel: forwarded unchanged to ``TorchSTFT``, which computes
            the spectrograms. ``use_mel`` is also kept on the instance.
    """

    def __init__(self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True):
        super().__init__()
        self.use_mel = use_mel
        # Everything except the three STFT sizes is handed over by keyword.
        stft_options = {
            'sample_rate': sample_rate,
            'mel_fmin': mel_fmin,
            'mel_fmax': mel_fmax,
            'n_mels': n_mels,
            'use_mel': use_mel,
        }
        self.stft = TorchSTFT(n_fft, hop_length, win_length, **stft_options)

    def forward(self, y_hat, y):
        """Return the L1 loss between the log spectrograms of ``y`` and ``y_hat``."""
        spec_fake = self.stft(y_hat)
        spec_real = self.stft(y)
        # The loss is taken in the log domain of the spectrogram values.
        log_real = torch.log(spec_real)
        log_fake = torch.log(spec_fake)
        return F.l1_loss(log_real, log_fake)
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss): class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
""" Multiscale STFT loss for multi band model outputs """ """ Multiscale STFT loss for multi band model outputs.
From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106"""
# pylint: disable=no-self-use # pylint: disable=no-self-use
def forward(self, y_hat, y): def forward(self, y_hat, y):
y_hat = y_hat.view(-1, 1, y_hat.shape[2]) y_hat = y_hat.view(-1, 1, y_hat.shape[2])
@ -143,9 +199,12 @@ class MelganFeatureLoss(nn.Module):
# pylint: disable=no-self-use # pylint: disable=no-self-use
def forward(self, fake_feats, real_feats): def forward(self, fake_feats, real_feats):
loss_feats = 0 loss_feats = 0
for fake_feat, real_feat in zip(fake_feats, real_feats): num_feats = 0
loss_feats += self.loss_func(fake_feat, real_feat) for idx, _ in enumerate(fake_feats):
loss_feats /= len(fake_feats) + len(real_feats) for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]):
loss_feats += self.loss_func(fake_feat, real_feat)
num_feats += 1
loss_feats = loss_feats / num_feats
return loss_feats return loss_feats
@ -198,24 +257,31 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):
class GeneratorLoss(nn.Module): class GeneratorLoss(nn.Module):
"""Generator Loss Wrapper. Based on model configuration it sets a right set of loss functions and computes
losses. It allows experimenting with different combinations of loss functions with different models by just
changing the configuration.
Args:
C (AttrDict): model configuration.
"""
def __init__(self, C): def __init__(self, C):
""" Compute Generator Loss values depending on training super().__init__()
configuration """
super(GeneratorLoss, self).__init__()
assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\ assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
" [!] Cannot use HingeGANLoss and MSEGANLoss together." " [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_stft_loss = C.use_stft_loss self.use_stft_loss = C.use_stft_loss if 'use_stft_loss' in C else False
self.use_subband_stft_loss = C.use_subband_stft_loss self.use_subband_stft_loss = C.use_subband_stft_loss if 'use_subband_stft_loss' in C else False
self.use_mse_gan_loss = C.use_mse_gan_loss self.use_mse_gan_loss = C.use_mse_gan_loss if 'use_mse_gan_loss' in C else False
self.use_hinge_gan_loss = C.use_hinge_gan_loss self.use_hinge_gan_loss = C.use_hinge_gan_loss if 'use_hinge_gan_loss' in C else False
self.use_feat_match_loss = C.use_feat_match_loss self.use_feat_match_loss = C.use_feat_match_loss if 'use_feat_match_loss' in C else False
self.use_l1_spec_loss = C.use_l1_spec_loss if 'use_l1_spec_loss' in C else False
self.stft_loss_weight = C.stft_loss_weight self.stft_loss_weight = C.stft_loss_weight if 'stft_loss_weight' in C else 0.0
self.subband_stft_loss_weight = C.subband_stft_loss_weight self.subband_stft_loss_weight = C.subband_stft_loss_weight if 'subband_stft_loss_weight' in C else 0.0
self.mse_gan_loss_weight = C.mse_G_loss_weight self.mse_gan_loss_weight = C.mse_G_loss_weight if 'mse_G_loss_weight' in C else 0.0
self.hinge_gan_loss_weight = C.hinge_G_loss_weight self.hinge_gan_loss_weight = C.hinge_G_loss_weight if 'hinde_G_loss_weight' in C else 0.0
self.feat_match_loss_weight = C.feat_match_loss_weight self.feat_match_loss_weight = C.feat_match_loss_weight if 'feat_match_loss_weight' in C else 0.0
self.l1_spec_loss_weight = C.l1_spec_loss_weight if 'l1_spec_loss_weight' in C else 0.0
if C.use_stft_loss: if C.use_stft_loss:
self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params) self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
@ -227,6 +293,9 @@ class GeneratorLoss(nn.Module):
self.hinge_loss = HingeGLoss() self.hinge_loss = HingeGLoss()
if C.use_feat_match_loss: if C.use_feat_match_loss:
self.feat_match_loss = MelganFeatureLoss() self.feat_match_loss = MelganFeatureLoss()
if C.use_l1_spec_loss:
assert C.audio['sample_rate'] == C.l1_spec_loss_params['sample_rate']
self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)
def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None): def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None):
gen_loss = 0 gen_loss = 0
@ -235,35 +304,41 @@ class GeneratorLoss(nn.Module):
# STFT Loss # STFT Loss
if self.use_stft_loss: if self.use_stft_loss:
stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat.squeeze(1), y.squeeze(1)) stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
return_dict['G_stft_loss_mg'] = stft_loss_mg return_dict['G_stft_loss_mg'] = stft_loss_mg
return_dict['G_stft_loss_sc'] = stft_loss_sc return_dict['G_stft_loss_sc'] = stft_loss_sc
gen_loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc) gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
# L1 Spec loss
if self.use_l1_spec_loss:
l1_spec_loss = self.l1_spec_loss(y_hat, y)
return_dict['G_l1_spec_loss'] = l1_spec_loss
gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss
# subband STFT Loss # subband STFT Loss
if self.use_subband_stft_loss: if self.use_subband_stft_loss:
subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub) subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg
return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc
gen_loss += self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc) gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
# multiscale MSE adversarial loss # multiscale MSE adversarial loss
if self.use_mse_gan_loss and scores_fake is not None: if self.use_mse_gan_loss and scores_fake is not None:
mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss) mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
return_dict['G_mse_fake_loss'] = mse_fake_loss return_dict['G_mse_fake_loss'] = mse_fake_loss
adv_loss += self.mse_gan_loss_weight * mse_fake_loss adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss
# multiscale Hinge adversarial loss # multiscale Hinge adversarial loss
if self.use_hinge_gan_loss and not scores_fake is not None: if self.use_hinge_gan_loss and not scores_fake is not None:
hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss) hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
return_dict['G_hinge_fake_loss'] = hinge_fake_loss return_dict['G_hinge_fake_loss'] = hinge_fake_loss
adv_loss += self.hinge_gan_loss_weight * hinge_fake_loss adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss
# Feature Matching Loss # Feature Matching Loss
if self.use_feat_match_loss and not feats_fake: if self.use_feat_match_loss and not feats_fake is None:
feat_match_loss = self.feat_match_loss(feats_fake, feats_real) feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
return_dict['G_feat_match_loss'] = feat_match_loss return_dict['G_feat_match_loss'] = feat_match_loss
adv_loss += self.feat_match_loss_weight * feat_match_loss adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
return_dict['G_loss'] = gen_loss + adv_loss return_dict['G_loss'] = gen_loss + adv_loss
return_dict['G_gen_loss'] = gen_loss return_dict['G_gen_loss'] = gen_loss
return_dict['G_adv_loss'] = adv_loss return_dict['G_adv_loss'] = adv_loss
@ -271,10 +346,9 @@ class GeneratorLoss(nn.Module):
class DiscriminatorLoss(nn.Module): class DiscriminatorLoss(nn.Module):
""" Compute Discriminator Loss values depending on training """Like ```GeneratorLoss```"""
configuration """
def __init__(self, C): def __init__(self, C):
super(DiscriminatorLoss, self).__init__() super().__init__()
assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\ assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
" [!] Cannot use HingeGANLoss and MSEGANLoss together." " [!] Cannot use HingeGANLoss and MSEGANLoss together."