Remove noise scale predictor

Edresson Casanova 2022-06-22 14:23:33 -03:00
parent ae1f443b35
commit 25e3221daf
1 changed file with 9 additions and 67 deletions


@@ -674,7 +674,6 @@ class VitsArgs(Coqpit):
     use_prosody_enc_spk_reversal_classifier: bool = False
     use_prosody_enc_emo_classifier: bool = False
-    use_noise_scale_predictor: bool = False
     use_latent_discriminator: bool = False
     use_encoder_conditional_module: bool = False
@@ -787,8 +786,8 @@ class Vits(BaseTTS):
             self.args.kernel_size_text_encoder,
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
-            emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_noise_scale_predictor else 0,
-            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0,
+            emotion_emb_dim=self.args.emotion_embedding_dim,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0,
             pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
         )
@@ -824,10 +823,10 @@ class Vits(BaseTTS):
             self.args.use_emotion_embedding
             or self.args.use_external_emotions_embeddings
             or self.args.use_speaker_embedding_as_emotion
-        ) and not self.args.use_noise_scale_predictor:
+        ):
             dp_extra_inp_dim += self.args.emotion_embedding_dim
-        if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
+        if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             dp_extra_inp_dim += self.args.prosody_embedding_dim
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
@@ -943,31 +942,6 @@ class Vits(BaseTTS):
                 reversal=False,
             )
-        if self.args.use_noise_scale_predictor:
-            noise_scale_predictor_input_dim = self.args.hidden_channels
-            if (
-                self.args.use_emotion_embedding
-                or self.args.use_external_emotions_embeddings
-                or self.args.use_speaker_embedding_as_emotion
-            ):
-                noise_scale_predictor_input_dim += self.args.emotion_embedding_dim
-            if self.args.use_prosody_encoder:
-                noise_scale_predictor_input_dim += self.args.prosody_embedding_dim
-            self.noise_scale_predictor = RelativePositionTransformer(
-                in_channels=noise_scale_predictor_input_dim,
-                out_channels=self.args.hidden_channels,
-                hidden_channels=noise_scale_predictor_input_dim,
-                hidden_channels_ffn=self.args.hidden_channels_ffn_text_encoder,
-                num_heads=self.args.num_heads_text_encoder,
-                num_layers=4,
-                kernel_size=self.args.kernel_size_text_encoder,
-                dropout_p=self.args.dropout_p_text_encoder,
-                layer_norm_type="2",
-                rel_attn_window_size=4,
-            )
         if self.args.use_emotion_embedding_squeezer:
             self.emotion_embedding_squeezer = nn.Linear(
                 in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim
@@ -1513,8 +1487,8 @@ class Vits(BaseTTS):
             x,
             x_lengths,
             lang_emb=lang_emb,
-            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
+            emo_emb=eg,
+            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
             pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )
@@ -1603,22 +1577,6 @@ class Vits(BaseTTS):
             z_decoder = self.z_decoder(x_expanded, y_mask, g=g_dec)
             z_decoder_loss = torch.nn.functional.l1_loss(z_decoder * y_mask, z)
-        if self.args.use_noise_scale_predictor:
-            nsp_input = torch.transpose(m_p_expanded, 1, -1)
-            if self.args.use_prosody_encoder and pros_emb is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-            if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-            nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
-            m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
-            m_p_expanded = m_p_expanded + m_p_noise_scale * torch.exp(logs_p_expanded)
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
@@ -1794,8 +1752,8 @@ class Vits(BaseTTS):
             x,
             x_lengths,
             lang_emb=lang_emb,
-            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
+            emo_emb=eg,
+            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
             pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )
@@ -1848,23 +1806,7 @@ class Vits(BaseTTS):
         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
         logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
-        if self.args.use_noise_scale_predictor:
-            nsp_input = torch.transpose(m_p, 1, -1)
-            if self.args.use_prosody_encoder and pros_emb is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-            if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-            nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
-            m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
-            z_p = m_p + m_p_noise_scale * torch.exp(logs_p)
-        else:
-            z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
         if self.args.use_z_decoder:
             cond_x = x
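
For reference, below is a minimal, self-contained sketch of the two sampling paths this commit chooses between. It is not the library code: NoiseScalePredictor is a one-layer stand-in for the RelativePositionTransformer the deleted branch instantiated, and the batch, channel, frame, and embedding sizes are made up; only the tensor plumbing follows the diff.

import torch
import torch.nn as nn

# Stand-in for RelativePositionTransformer (an assumption, not the library
# module): maps the conditioned prior back to hidden_channels, masked per frame.
class NoiseScalePredictor(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, y_mask):
        return self.proj(x) * y_mask

B, C, T = 2, 192, 50             # hypothetical batch, hidden_channels, frames
EMO = 64                         # hypothetical emotion_embedding_dim
m_p = torch.randn(B, C, T)       # prior mean, expanded to frame rate
logs_p = torch.randn(B, C, T)    # prior log-scale
y_mask = torch.ones(B, 1, T)     # frame mask
eg = torch.randn(B, EMO, 1)      # emotion embedding, shape [B, emb_dim, 1]

# Deleted path: predict a per-frame noise scale from the prior mean plus
# the emotion conditioning (the prosody branch works the same way).
predictor = NoiseScalePredictor(C + EMO, C)
nsp_input = torch.transpose(m_p, 1, -1)                        # [B, T, C]
nsp_input = torch.cat(
    (nsp_input, eg.transpose(2, 1).expand(B, T, -1)), dim=-1   # [B, T, C + EMO]
)
nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask         # [B, C + EMO, T]
m_p_noise_scale = predictor(nsp_input, y_mask)
z_p_learned = m_p + m_p_noise_scale * torch.exp(logs_p)

# Retained path: plain VITS sampling with a scalar inference-time knob.
inference_noise_scale = 0.667
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * inference_noise_scale

The retained line keeps the stochastic draw torch.randn_like(m_p) scaled by a single scalar, whereas the removed branch replaced it with a deterministic, learned per-frame quantity.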