mirror of https://github.com/coqui-ai/TTS.git
Remove noise scale predictor
This commit is contained in:
parent ae1f443b35
commit 25e3221daf
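In short: this commit deletes the optional noise scale predictor from the VITS model on this branch. It removes the use_noise_scale_predictor flag from VitsArgs, the RelativePositionTransformer module that implemented the predictor, and the training/inference branches that replaced the random prior noise with the predictor's output. The sketch below (plain Python; the names follow the diff, but the shapes and the helper signature are illustrative assumptions, not the repository's exact code) shows the two sampling paths the flag used to switch between; only the second one survives this commit.

    import torch

    def sample_prior(m_p, logs_p, noise_scale=1.0, noise_scale_predictor=None, cond=None):
        # m_p, logs_p: prior mean / log-scale, shape [B, C, T] (as in the diff).
        # noise_scale_predictor: the module removed by this commit; when present it
        # produced a learned, deterministic "noise" term instead of sampling.
        # cond: optional prosody/emotion embedding [B, E, 1], broadcast over time.
        if noise_scale_predictor is not None:
            nsp_input = m_p if cond is None else torch.cat(
                (m_p, cond.expand(-1, -1, m_p.size(2))), dim=1
            )
            learned_noise = noise_scale_predictor(nsp_input)
            return m_p + learned_noise * torch.exp(logs_p)
        # standard VITS sampling, the only path left after this commit
        return m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale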
@@ -674,7 +674,6 @@ class VitsArgs(Coqpit):
     use_prosody_enc_spk_reversal_classifier: bool = False
     use_prosody_enc_emo_classifier: bool = False
 
-    use_noise_scale_predictor: bool = False
     use_latent_discriminator: bool = False
 
     use_encoder_conditional_module: bool = False
@@ -787,8 +786,8 @@ class Vits(BaseTTS):
             self.args.kernel_size_text_encoder,
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
-            emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_noise_scale_predictor else 0,
-            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0,
+            emotion_emb_dim=self.args.emotion_embedding_dim,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0,
             pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
         )
 
@@ -824,10 +823,10 @@ class Vits(BaseTTS):
             self.args.use_emotion_embedding
             or self.args.use_external_emotions_embeddings
             or self.args.use_speaker_embedding_as_emotion
-        ) and not self.args.use_noise_scale_predictor:
+        ):
             dp_extra_inp_dim += self.args.emotion_embedding_dim
 
-        if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
+        if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             dp_extra_inp_dim += self.args.prosody_embedding_dim
 
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
@@ -943,31 +942,6 @@ class Vits(BaseTTS):
                 reversal=False,
             )
 
-        if self.args.use_noise_scale_predictor:
-            noise_scale_predictor_input_dim = self.args.hidden_channels
-            if (
-                self.args.use_emotion_embedding
-                or self.args.use_external_emotions_embeddings
-                or self.args.use_speaker_embedding_as_emotion
-            ):
-                noise_scale_predictor_input_dim += self.args.emotion_embedding_dim
-
-            if self.args.use_prosody_encoder:
-                noise_scale_predictor_input_dim += self.args.prosody_embedding_dim
-
-            self.noise_scale_predictor = RelativePositionTransformer(
-                in_channels=noise_scale_predictor_input_dim,
-                out_channels=self.args.hidden_channels,
-                hidden_channels=noise_scale_predictor_input_dim,
-                hidden_channels_ffn=self.args.hidden_channels_ffn_text_encoder,
-                num_heads=self.args.num_heads_text_encoder,
-                num_layers=4,
-                kernel_size=self.args.kernel_size_text_encoder,
-                dropout_p=self.args.dropout_p_text_encoder,
-                layer_norm_type="2",
-                rel_attn_window_size=4,
-            )
-
         if self.args.use_emotion_embedding_squeezer:
             self.emotion_embedding_squeezer = nn.Linear(
                 in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim
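The constructor block deleted just above sized the predictor's input from the text encoder's hidden channels plus whichever conditioning embeddings were active, then built a 4-layer RelativePositionTransformer on top of it. A small sketch of that dimension bookkeeping, using the argument names from the diff (the args object is a stand-in for VitsArgs):

    def noise_scale_predictor_input_dim(args) -> int:
        # Mirrors the removed logic: start from the encoder width ...
        dim = args.hidden_channels
        # ... grow by the emotion embedding if any emotion conditioning is enabled ...
        if (
            args.use_emotion_embedding
            or args.use_external_emotions_embeddings
            or args.use_speaker_embedding_as_emotion
        ):
            dim += args.emotion_embedding_dim
        # ... and by the prosody embedding if the prosody encoder is on.
        if args.use_prosody_encoder:
            dim += args.prosody_embedding_dim
        return dim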
@@ -1513,8 +1487,8 @@ class Vits(BaseTTS):
             x,
             x_lengths,
             lang_emb=lang_emb,
-            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
+            emo_emb=eg,
+            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
             pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )
 
@@ -1603,22 +1577,6 @@ class Vits(BaseTTS):
             z_decoder = self.z_decoder(x_expanded, y_mask, g=g_dec)
             z_decoder_loss = torch.nn.functional.l1_loss(z_decoder * y_mask, z)
 
-        if self.args.use_noise_scale_predictor:
-            nsp_input = torch.transpose(m_p_expanded, 1, -1)
-            if self.args.use_prosody_encoder and pros_emb is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-
-            if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-
-            nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
-            m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
-            m_p_expanded = m_p_expanded + m_p_noise_scale * torch.exp(logs_p_expanded)
-
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
 
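The two context lines that close this hunk are the usual VITS training trick: a random fixed-size window is cut from the posterior latent z and only that slice is passed to the waveform decoder. A hedged behavioural sketch of rand_segments (not the library's implementation, and it ignores the let_short_samples/pad_short handling):

    import torch

    def rand_segments_sketch(x, x_lengths, segment_size):
        # x: [B, C, T], x_lengths: [B]; returns one random segment per batch item
        # plus the start indices, assuming every item has at least segment_size frames.
        max_start = (x_lengths - segment_size).clamp(min=0)
        starts = (torch.rand(x.size(0)) * (max_start + 1).float()).long()
        segments = torch.stack(
            [x[i, :, s : s + segment_size] for i, s in enumerate(starts.tolist())]
        )
        return segments, starts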
@@ -1794,8 +1752,8 @@ class Vits(BaseTTS):
             x,
             x_lengths,
             lang_emb=lang_emb,
-            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
+            emo_emb=eg,
+            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
             pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )
 
@@ -1848,23 +1806,7 @@ class Vits(BaseTTS):
         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
         logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
 
-        if self.args.use_noise_scale_predictor:
-            nsp_input = torch.transpose(m_p, 1, -1)
-            if self.args.use_prosody_encoder and pros_emb is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-
-            if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-
-            nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
-            m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
-            z_p = m_p + m_p_noise_scale * torch.exp(logs_p)
-        else:
-            z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
 
         if self.args.use_z_decoder:
             cond_x = x
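After this commit the only stochasticity at inference is the line that replaces the deleted branch: standard-normal noise scaled by exp(logs_p) and by inference_noise_scale. A tiny runnable illustration of how that knob behaves (the values are arbitrary, not the repository's defaults):

    import torch

    torch.manual_seed(0)
    m_p = torch.zeros(1, 2, 4)       # prior mean, [B, C, T]
    logs_p = torch.zeros(1, 2, 4)    # log std = 0, i.e. unit std

    for inference_noise_scale in (0.0, 0.333, 0.667):
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * inference_noise_scale
        print(inference_noise_scale, z_p.abs().mean().item())
    # scale 0.0 collapses z_p onto the prior mean; larger scales add more per-run variation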