mirror of https://github.com/coqui-ai/TTS.git
Fix Lint checks
commit ae55bdae6c
parent fd1036f4ba
@@ -60,4 +60,3 @@ class ReversalClassifier(nn.Module):
        target = labels.repeat(input_mask.size(-1), 1).transpose(0, 1).int().long()
        target[~input_mask] = ignore_index
        return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)

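For context, the masking trick in this hunk pads the classifier targets with `ignore_index` so padded frames contribute nothing to the loss. A minimal self-contained sketch of the same mechanics (shapes and values are illustrative assumptions, not taken from the repo):

```python
import torch
import torch.nn as nn

# Assumed shapes: predictions (B, T, C), labels (B,), input_mask (B, T) boolean.
ignore_index = -100
B, T, C = 2, 5, 4
predictions = torch.randn(B, T, C)
labels = torch.tensor([1, 3])
input_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)

# Broadcast each utterance-level label over its time axis: (B,) -> (B, T).
target = labels.repeat(input_mask.size(-1), 1).transpose(0, 1).long()
# Padded frames are set to ignore_index so cross_entropy skips them.
target[~input_mask] = ignore_index
# cross_entropy wants the class dim second: predictions (B, C, T), target (B, T).
loss = nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
print(loss.item())
```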
@@ -761,20 +761,14 @@ class VitsGeneratorLoss(nn.Module):
            loss += kl_vae_loss
            return_dict["loss_kl_vae"] = kl_vae_loss

        if end2end_info is not None:
-           # do not compute feature loss because for it we need waves segments with the same length
-           '''loss_feat_end2end = (
-               self.feature_loss(feats_real=end2end_info["feats_disc_real"], feats_generated=end2end_info["feats_disc_fake"]) * self.feat_loss_alpha
-           )
-           return_dict["loss_feat_end2end"] = loss_feat_end2end
-           loss += loss_feat_end2end'''

            # gen loss
            loss_gen_end2end = self.generator_loss(scores_fake=end2end_info["scores_disc_fake"])[0] * self.gen_loss_alpha
            return_dict["loss_gen_end2end"] = loss_gen_end2end
            loss += loss_gen_end2end


            # if do not uses soft dtw
            if end2end_info["z_predicted"] is not None:
                # loss KL using GT durations

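The `generator_loss` call above returns a tuple whose first element is the summed adversarial loss, hence the `[0]`. A minimal sketch of the least-squares (LSGAN) formulation commonly used in VITS-style models; the repo's exact implementation may differ in detail:

```python
import torch

def generator_loss(scores_fake):
    """LSGAN generator term: push each sub-discriminator's fake scores toward 1."""
    loss = 0.0
    per_disc_losses = []
    for dg in scores_fake:  # one score tensor per sub-discriminator
        l_g = torch.mean((1.0 - dg.float()) ** 2)
        per_disc_losses.append(l_g)
        loss = loss + l_g
    return loss, per_disc_losses

# Usage mirroring the call in the hunk: keep only the summed scalar.
scores = [torch.rand(2, 1, 100), torch.rand(2, 1, 50)]
loss_gen = generator_loss(scores_fake=scores)[0]
```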
@@ -793,7 +787,7 @@ class VitsGeneratorLoss(nn.Module):
            else:
                pass
                # ToDo: implement soft dtw


        # pass losses to the dict
        return_dict["loss_gen"] = loss_gen
        return_dict["loss_kl"] = loss_kl

@@ -854,7 +848,7 @@ class VitsDiscriminatorLoss(nn.Module):
            loss_disc_end2end, loss_disc_real_end2end, _ = self.discriminator_loss(
                scores_real=end2end_info["scores_disc_real"], scores_fake=end2end_info["scores_disc_fake"],
            )
            return_dict["loss_disc_end2end"] = loss_disc_end2end * self.disc_loss_alpha
            return_dict["loss"] += return_dict["loss_disc_end2end"]

            for i, ldr in enumerate(loss_disc_real_end2end):

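`self.discriminator_loss` returns a triple: the total loss plus per-discriminator real and fake terms (the real terms are then logged one by one in the `for i, ldr ...` loop). A minimal LSGAN-style sketch consistent with that signature, which may differ in detail from the repo's implementation:

```python
import torch

def discriminator_loss(scores_real, scores_fake):
    """LSGAN discriminator term: real scores toward 1, fake scores toward 0."""
    loss = 0.0
    real_losses, fake_losses = [], []
    for dr, dg in zip(scores_real, scores_fake):
        real_loss = torch.mean((1.0 - dr.float()) ** 2)
        fake_loss = torch.mean(dg.float() ** 2)
        loss = loss + real_loss + fake_loss
        real_losses.append(real_loss.item())
        fake_losses.append(fake_loss.item())
    return loss, real_losses, fake_losses
```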
@@ -94,9 +94,9 @@ class VitsDiscriminator(nn.Module):
        mp_scores, zp_scores, mp_feats, zp_feats = None, None, None, None
        if self.disc_latent is not None:
            if m_p is not None:
-               mp_scores, mp_feats = self.disc_latent(m_p.unsqueeze(1))
+               mp_scores, mp_feats = self.disc_latent(m_p.unsqueeze(-1))
            if z_p is not None:
-               zp_scores, zp_feats = self.disc_latent(z_p.unsqueeze(1))
+               zp_scores, zp_feats = self.disc_latent(z_p.unsqueeze(-1))

        return x_scores, x_feats, x_hat_scores, x_hat_feats, mp_scores, mp_feats, zp_scores, zp_feats

@@ -107,7 +107,6 @@ class LatentDiscriminator(nn.Module):
    def __init__(self, use_spectral_norm=False, hidden_channels=None):
        super().__init__()
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
-       self.hidden_channels = hidden_channels
        self.discriminators = nn.ModuleList(
            [
                norm_f(nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))),

@@ -122,8 +121,6 @@ class LatentDiscriminator(nn.Module):

    def forward(self, y):
        fmap = []
-       if self.hidden_channels is not None:
-           y = y.squeeze(1).unsqueeze(-1)
        for _, d in enumerate(self.discriminators):
            y = d(y)
            y = torch.nn.functional.leaky_relu(y, 0.1)

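Taken together, the hunks above move the reshape out of `LatentDiscriminator.forward`: call sites now `unsqueeze(-1)` the latent themselves, and the Conv2d stack reads the latent channels as its input channels. A shape walk-through under assumed dimensions:

```python
import torch

B, hidden_channels, T = 2, 192, 50  # assumed VITS-like latent shape (B, C, T)
m_p = torch.randn(B, hidden_channels, T)

# Old call site: unsqueeze(1) -> (B, 1, C, T), a one-channel "image",
# which forward() then had to reshape whenever hidden_channels was set.
old_input = m_p.unsqueeze(1)     # torch.Size([2, 1, 192, 50])

# New call site: unsqueeze(-1) -> (B, C, T, 1), so the Conv2d consumes the
# latent channels directly (cf. in_channels=hidden_channels in __init__).
new_input = m_p.unsqueeze(-1)    # torch.Size([2, 192, 50, 1])

conv = torch.nn.Conv2d(hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))
out = conv(new_input)            # torch.Size([2, 32, 50, 1])
```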
@@ -20,7 +20,7 @@ from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
-from TTS.tts.layers.vits.discriminator import VitsDiscriminator, LatentDiscriminator
+from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor

@@ -556,7 +556,7 @@ class VitsArgs(Coqpit):
    use_prosody_encoder_z_p_input: bool = False
    use_prosody_enc_spk_reversal_classifier: bool = False
    use_prosody_enc_emo_classifier: bool = False


    use_noise_scale_predictor: bool = False

    use_prosody_conditional_flow_module: bool = False

@@ -567,7 +567,6 @@ class VitsArgs(Coqpit):
    use_soft_dtw: bool = False

    use_latent_discriminator: bool = False
-   provide_hidden_dim_on_the_latent_discriminator: bool = False

    detach_dp_input: bool = True
    use_language_embedding: bool = False

@@ -686,10 +685,10 @@ class Vits(BaseTTS):

        dp_extra_inp_dim = 0
        if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor:
            dp_extra_inp_dim += self.args.emotion_embedding_dim

        if self.args.use_prosody_encoder and not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor:
            dp_extra_inp_dim += self.args.prosody_embedding_dim

        if self.args.use_sdp:
            self.duration_predictor = StochasticDurationPredictor(

@@ -724,7 +723,7 @@ class Vits(BaseTTS):
                num_mel=self.args.hidden_channels,
                capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
            )
        else:
            raise RuntimeError(
                f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
            )

@@ -734,7 +733,7 @@ class Vits(BaseTTS):
                out_channels=self.num_speakers,
                hidden_channels=256,
            )
        if self.args.use_prosody_enc_emo_classifier:
            self.pros_enc_emotion_classifier = ReversalClassifier(
                in_channels=self.args.prosody_embedding_dim,
                out_channels=self.num_emotions,

@@ -817,7 +816,7 @@ class Vits(BaseTTS):
                periods=self.args.periods_multi_period_discriminator,
                use_spectral_norm=self.args.use_spectral_norm_disriminator,
                use_latent_disc=self.args.use_latent_discriminator,
-               hidden_channels=self.args.hidden_channels if self.args.provide_hidden_dim_on_the_latent_discriminator else None,
+               hidden_channels=self.args.hidden_channels,
            )

    def init_multispeaker(self, config: Coqpit):

@@ -952,7 +951,7 @@ class Vits(BaseTTS):
                if value == before_dict[key]:
                    raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
        print(" > Text Encoder was reinit.")


    def init_emotion(self, emotion_manager: EmotionManager):
        # pylint: disable=attribute-defined-outside-init
        """Initialize emotion modules of a model. A model can be trained either with a emotion embedding layer

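The loop above guards against a silent no-op: after re-initializing the text encoder, each tensor is compared against its pre-reset copy and an error is raised if any is unchanged. A standalone sketch of the same idea; the helper name and comparison via `torch.equal` are assumptions, not the repo's code:

```python
import copy
import torch

def assert_reinitialized(module: torch.nn.Module, reinit_fn):
    """Raise if reinit_fn left any tensor of `module` bit-identical (hypothetical helper)."""
    before_dict = copy.deepcopy(module.state_dict())
    reinit_fn(module)
    for key, value in module.state_dict().items():
        if torch.equal(value, before_dict[key]):
            raise RuntimeError(f" [!] {key} was not reinitialized, check it!")
    print(" > Module was reinit.")

# Example: re-randomize a linear layer and verify the reset actually happened.
layer = torch.nn.Linear(4, 4)
assert_reinitialized(layer, lambda m: m.reset_parameters())
```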
@@ -1345,7 +1344,7 @@ class Vits(BaseTTS):
            z_p_end2end = self.prosody_conditional_module(z_p_end2end, y_mask_end2end, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb, reverse=True)

        z_end2end = self.flow(z_p_end2end, y_mask_end2end, g=g, reverse=True)


        # interpolate z if needed
        z_end2end, _, _, y_mask_end2end = self.upsampling_z(z, y_lengths=y_lengths_end2end, y_mask=y_mask_end2end)
        # z_slice_end2end, spec_segment_size, slice_ids_end2end, _ = self.upsampling_z(z_slice_end2end, slice_ids=slice_ids_end2end)

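The "interpolate z if needed" step stretches the latent along time when the vocoder runs at a higher rate than the encoder. A rough sketch of that idea; the signature and internals of `upsampling_z` here are assumptions:

```python
import torch

# Assumed latent shape (B, C, T); stretch the time axis so the decoder
# sees frames at the output sample rate.
B, C, T = 2, 192, 40
z = torch.randn(B, C, T)
interpolate_factor = 2  # hypothetical output/model sample-rate ratio
z_up = torch.nn.functional.interpolate(z, scale_factor=interpolate_factor, mode="nearest")
print(z_up.shape)  # torch.Size([2, 192, 80])
```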
@@ -1505,7 +1504,7 @@ class Vits(BaseTTS):

        m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)


        if self.args.use_noise_scale_predictor:
            nsp_input = torch.transpose(m_p, 1, -1)
            if self.args.use_prosody_encoder and pros_emb is not None:

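The two `matmul` lines expand the text-level prior statistics to frame level through the alignment matrix. With assumed shapes, the transposes line up as follows:

```python
import torch

# Assumed shapes: attn (B, T_text, T_frames) hard alignment from durations,
# m_p / logs_p (B, C, T_text) per-phoneme prior mean / log-variance.
B, C, T_text, T_frames = 2, 192, 12, 48
attn = torch.zeros(B, T_text, T_frames)
attn[:, 0, :] = 1.0  # toy alignment: every frame maps to the first phoneme
m_p = torch.randn(B, C, T_text)

# (B, T_frames, T_text) @ (B, T_text, C) -> (B, T_frames, C) -> (B, C, T_frames)
m_p_frames = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
print(m_p_frames.shape)  # torch.Size([2, 192, 48])
```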
@@ -1850,7 +1849,7 @@ class Vits(BaseTTS):

        if style_wav and style_speaker_name is None:
            raise RuntimeError(
-               f" [!] You must to provide the style_speaker_name for the style_wav !!"
+               " [!] You must to provide the style_speaker_name for the style_wav !!"
            )

        # get speaker id/d_vector

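The change above addresses pylint's W1309 (f-string-without-interpolation): the message has no placeholders, so the `f` prefix is dropped. For illustration:

```python
name = "style_wav"
bad = f" [!] You must to provide the style_speaker_name for the style_wav !!"   # W1309: nothing interpolated
good = " [!] You must to provide the style_speaker_name for the style_wav !!"
fine = f" [!] Missing style_speaker_name for {name} !!"  # the f-prefix earns its keep here
```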
@@ -423,7 +423,7 @@ class Synthesizer(object):
        source_emotion_feature, target_emotion_feature = None, None
        if source_emotion is not None and target_emotion is not None and not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
            getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
-       )):
+       )): # pylint: disable=R0916
            if source_emotion and isinstance(source_emotion, str):
                if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
                    getattr(self.tts_config, "model_args", None)

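Pylint's R0916 (too-many-boolean-expressions) fires when a single `if` chains more than five boolean sub-expressions; the commit keeps the condition intact and silences the check inline. A hypothetical alternative, for illustration only, is to hoist the clauses into named predicates:

```python
def should_compute_emotion_features(tts_model, tts_emotions_file, source_emotion, target_emotion):
    """Name each clause so no single `if` chains more than five boolean tests."""
    has_both = source_emotion is not None and target_emotion is not None
    no_prosody_encoder = not getattr(tts_model, "prosody_encoder", False)
    has_emotion_ids = bool(
        tts_emotions_file
        or (getattr(tts_model, "emotion_manager", None)
            and getattr(tts_model.emotion_manager, "ids", None))
    )
    return has_both and no_prosody_encoder and has_emotion_ids
```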
@@ -53,7 +53,6 @@ config.model_args.prosody_encoder_type = "gst"
config.model_args.detach_prosody_enc_input = True

config.model_args.use_latent_discriminator = True
-config.model_args.provide_hidden_dim_on_the_latent_discriminator = True
config.model_args.use_noise_scale_predictor = False

# enable end2end loss