Fix Lint checks

Edresson Casanova 2022-06-06 14:59:21 -03:00
parent fd1036f4ba
commit ae55bdae6c
6 changed files with 18 additions and 30 deletions

View File

@@ -60,4 +60,3 @@ class ReversalClassifier(nn.Module):
target = labels.repeat(input_mask.size(-1), 1).transpose(0, 1).int().long()
target[~input_mask] = ignore_index
return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
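
For reference, a minimal self-contained sketch of the masking pattern above; shapes and values are illustrative assumptions, not the repository's code:

import torch
import torch.nn as nn

B, T, C = 2, 5, 4                          # batch, frames, classes (e.g. speakers)
predictions = torch.randn(B, T, C)         # per-frame logits from the classifier
labels = torch.tensor([1, 3])              # one class id per utterance
input_mask = torch.tensor([[1, 1, 1, 0, 0],
                           [1, 1, 1, 1, 1]]).bool()
ignore_index = -100

# repeat the utterance-level label across time, then blank out padded frames
target = labels.repeat(input_mask.size(-1), 1).transpose(0, 1).long()
target[~input_mask] = ignore_index

# cross_entropy expects (batch, classes, time) for sequence inputs, hence the
# transpose; frames set to ignore_index contribute nothing to the loss
loss = nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)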

View File

@@ -761,20 +761,14 @@ class VitsGeneratorLoss(nn.Module):
loss += kl_vae_loss
return_dict["loss_kl_vae"] = kl_vae_loss
if end2end_info is not None:
- # do not compute feature loss because for it we need waves segments with the same length
- '''loss_feat_end2end = (
- self.feature_loss(feats_real=end2end_info["feats_disc_real"], feats_generated=end2end_info["feats_disc_fake"]) * self.feat_loss_alpha
- )
- return_dict["loss_feat_end2end"] = loss_feat_end2end
- loss += loss_feat_end2end'''
# gen loss
loss_gen_end2end = self.generator_loss(scores_fake=end2end_info["scores_disc_fake"])[0] * self.gen_loss_alpha
return_dict["loss_gen_end2end"] = loss_gen_end2end
loss += loss_gen_end2end
# if not using soft dtw
if end2end_info["z_predicted"] is not None:
# loss KL using GT durations
@@ -793,7 +787,7 @@ class VitsGeneratorLoss(nn.Module):
else:
pass
# ToDo: implement soft dtw
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl
@@ -854,7 +848,7 @@ class VitsDiscriminatorLoss(nn.Module):
loss_disc_end2end, loss_disc_real_end2end, _ = self.discriminator_loss(
scores_real=end2end_info["scores_disc_real"], scores_fake=end2end_info["scores_disc_fake"],
)
return_dict["loss_disc_end2end"] = loss_disc_end2end * self.disc_loss_alpha
return_dict["loss_disc_end2end"] = loss_disc_end2end * self.disc_loss_alpha
return_dict["loss"] += return_dict["loss_disc_end2end"]
for i, ldr in enumerate(loss_disc_real_end2end):
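
The generator_loss and discriminator_loss calls in these hunks are the usual least-squares GAN criteria of VITS-style models. A hedged sketch of what they compute, with the return layout inferred from the tuple unpacking above (not necessarily the exact implementation):

import torch

def generator_loss(scores_fake):
    # LSGAN generator term: push every fake discriminator score toward 1
    gen_losses = []
    total = 0.0
    for score_fake in scores_fake:
        loss_fake = torch.mean((1.0 - score_fake.float()) ** 2)
        gen_losses.append(loss_fake)
        total = total + loss_fake
    return total, gen_losses

def discriminator_loss(scores_real, scores_fake):
    # LSGAN discriminator term: real scores toward 1, fake scores toward 0
    real_losses, fake_losses = [], []
    total = 0.0
    for score_real, score_fake in zip(scores_real, scores_fake):
        loss_real = torch.mean((1.0 - score_real.float()) ** 2)
        loss_fake = torch.mean(score_fake.float() ** 2)
        total = total + loss_real + loss_fake
        real_losses.append(loss_real.item())
        fake_losses.append(loss_fake.item())
    return total, real_losses, fake_losses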

View File

@@ -94,9 +94,9 @@ class VitsDiscriminator(nn.Module):
mp_scores, zp_scores, mp_feats, zp_feats = None, None, None, None
if self.disc_latent is not None:
if m_p is not None:
- mp_scores, mp_feats = self.disc_latent(m_p.unsqueeze(1))
+ mp_scores, mp_feats = self.disc_latent(m_p.unsqueeze(-1))
if z_p is not None:
- zp_scores, zp_feats = self.disc_latent(z_p.unsqueeze(1))
+ zp_scores, zp_feats = self.disc_latent(z_p.unsqueeze(-1))
return x_scores, x_feats, x_hat_scores, x_hat_feats, mp_scores, mp_feats, zp_scores, zp_feats
@@ -107,7 +107,6 @@ class LatentDiscriminator(nn.Module):
def __init__(self, use_spectral_norm=False, hidden_channels=None):
super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
- self.hidden_channels = hidden_channels
self.discriminators = nn.ModuleList(
[
norm_f(nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))),
@@ -122,8 +121,6 @@ class LatentDiscriminator(nn.Module):
def forward(self, y):
fmap = []
- if self.hidden_channels is not None:
- y = y.squeeze(1).unsqueeze(-1)
for _, d in enumerate(self.discriminators):
y = d(y)
y = torch.nn.functional.leaky_relu(y, 0.1)
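
The two unsqueeze variants in the hunks above feed Conv2d with different layouts; a sketch with assumed shapes:

import torch

B, C, T = 2, 192, 100              # batch, latent channels, frames (assumed)
m_p = torch.randn(B, C, T)

# old layout: the latent becomes a single-channel "image" of size C x T,
# matching Conv2d(1, 32, ...) when hidden_channels is None
as_image = m_p.unsqueeze(1)        # (B, 1, C, T)

# new layout: each latent channel is a Conv2d input channel,
# matching Conv2d(hidden_channels, 32, ...)
as_channels = m_p.unsqueeze(-1)    # (B, C, T, 1)

print(as_image.shape, as_channels.shape)  # torch.Size([2, 1, 192, 100]) torch.Size([2, 192, 100, 1])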

View File

@@ -20,7 +20,7 @@ from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
- from TTS.tts.layers.vits.discriminator import VitsDiscriminator, LatentDiscriminator
+ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
@@ -556,7 +556,7 @@ class VitsArgs(Coqpit):
use_prosody_encoder_z_p_input: bool = False
use_prosody_enc_spk_reversal_classifier: bool = False
use_prosody_enc_emo_classifier: bool = False
use_noise_scale_predictor: bool = False
use_prosody_conditional_flow_module: bool = False
@@ -567,7 +567,6 @@ class VitsArgs(Coqpit):
use_soft_dtw: bool = False
use_latent_discriminator: bool = False
- provide_hidden_dim_on_the_latent_discriminator: bool = False
detach_dp_input: bool = True
use_language_embedding: bool = False
@@ -686,10 +685,10 @@ class Vits(BaseTTS):
dp_extra_inp_dim = 0
if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor:
dp_extra_inp_dim += self.args.emotion_embedding_dim
if self.args.use_prosody_encoder and not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor:
dp_extra_inp_dim += self.args.prosody_embedding_dim
if self.args.use_sdp:
self.duration_predictor = StochasticDurationPredictor(
@@ -724,7 +723,7 @@ class Vits(BaseTTS):
num_mel=self.args.hidden_channels,
capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
)
else:
raise RuntimeError(
f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
)
@@ -734,7 +733,7 @@ class Vits(BaseTTS):
out_channels=self.num_speakers,
hidden_channels=256,
)
if self.args.use_prosody_enc_emo_classifier:
self.pros_enc_emotion_classifier = ReversalClassifier(
in_channels=self.args.prosody_embedding_dim,
out_channels=self.num_emotions,
@@ -817,7 +816,7 @@ class Vits(BaseTTS):
periods=self.args.periods_multi_period_discriminator,
use_spectral_norm=self.args.use_spectral_norm_disriminator,
use_latent_disc=self.args.use_latent_discriminator,
- hidden_channels=self.args.hidden_channels if self.args.provide_hidden_dim_on_the_latent_discriminator else None,
+ hidden_channels=self.args.hidden_channels,
)
def init_multispeaker(self, config: Coqpit):
@@ -952,7 +951,7 @@ class Vits(BaseTTS):
if value == before_dict[key]:
raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
print(" > Text Encoder was reinit.")
def init_emotion(self, emotion_manager: EmotionManager):
# pylint: disable=attribute-defined-outside-init
"""Initialize emotion modules of a model. A model can be trained either with a emotion embedding layer
@@ -1345,7 +1344,7 @@ class Vits(BaseTTS):
z_p_end2end = self.prosody_conditional_module(z_p_end2end, y_mask_end2end, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb, reverse=True)
z_end2end = self.flow(z_p_end2end, y_mask_end2end, g=g, reverse=True)
# interpolate z if needed
z_end2end, _, _, y_mask_end2end = self.upsampling_z(z, y_lengths=y_lengths_end2end, y_mask=y_mask_end2end)
# z_slice_end2end, spec_segment_size, slice_ids_end2end, _ = self.upsampling_z(z_slice_end2end, slice_ids=slice_ids_end2end)
@@ -1505,7 +1504,7 @@ class Vits(BaseTTS):
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
if self.args.use_noise_scale_predictor:
nsp_input = torch.transpose(m_p, 1, -1)
if self.args.use_prosody_encoder and pros_emb is not None:
@@ -1850,7 +1849,7 @@ class Vits(BaseTTS):
if style_wav and style_speaker_name is None:
raise RuntimeError(
f" [!] You must to provide the style_speaker_name for the style_wav !!"
" [!] You must to provide the style_speaker_name for the style_wav !!"
)
# get speaker id/d_vector
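
The string change above fixes pylint W1309 (f-string-without-interpolation): an f-prefix on a string with no placeholder is dead weight. A small illustrative sketch:

# before: f" [!] ..." with no {placeholder} -- flagged by W1309
# after: a plain string carries the same text
msg = " [!] You must provide the style_speaker_name for the style_wav !!"

# the f-prefix stays when something is actually interpolated, e.g.:
kind = "vae"
msg_interp = f" [!] The Prosody encoder type {kind} is not supported !!"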

View File

@@ -423,7 +423,7 @@ class Synthesizer(object):
source_emotion_feature, target_emotion_feature = None, None
if source_emotion is not None and target_emotion is not None and not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
- )):
+ )): # pylint: disable=R0916
if source_emotion and isinstance(source_emotion, str):
if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
getattr(self.tts_config, "model_args", None)
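
The # pylint: disable=R0916 comment above silences too-many-boolean-expressions. A purely hypothetical alternative would name the condition instead; this helper and its name are not in the codebase:

def emotion_features_available(synthesizer, source_emotion, target_emotion) -> bool:
    # hypothetical helper mirroring the inline condition above
    if source_emotion is None or target_emotion is None:
        return False
    if getattr(synthesizer.tts_model, "prosody_encoder", False):
        return False
    has_manager_ids = bool(
        getattr(synthesizer.tts_model, "emotion_manager", None)
        and getattr(synthesizer.tts_model.emotion_manager, "ids", None)
    )
    return bool(synthesizer.tts_emotions_file) or has_manager_ids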

View File

@@ -53,7 +53,6 @@ config.model_args.prosody_encoder_type = "gst"
config.model_args.detach_prosody_enc_input = True
config.model_args.use_latent_discriminator = True
- config.model_args.provide_hidden_dim_on_the_latent_discriminator = True
config.model_args.use_noise_scale_predictor = False
# enable end2end loss
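
For context, the flags in this test are fields of the VitsArgs coqpit shown earlier, so the same setup can also be expressed at construction time. A hedged sketch; the import path is assumed from this fork's layout:

from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    prosody_encoder_type="gst",
    detach_prosody_enc_input=True,
    use_latent_discriminator=True,  # builds the LatentDiscriminator inside VitsDiscriminator
)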