mirror of https://github.com/coqui-ai/TTS.git
Add Noise scale predictor
This commit is contained in:
parent
cbc81b55cb
commit
0d7f8e24b2
|
@ -58,14 +58,14 @@ class VitsDiscriminator(nn.Module):
|
||||||
use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
|
use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False, use_latent_disc=False):
|
def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False, use_latent_disc=False, hidden_channels=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.nets = nn.ModuleList()
|
self.nets = nn.ModuleList()
|
||||||
self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
|
self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
|
||||||
self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
|
self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
|
||||||
self.disc_latent = None
|
self.disc_latent = None
|
||||||
if use_latent_disc:
|
if use_latent_disc:
|
||||||
self.disc_latent = LatentDiscriminator(use_spectral_norm=use_spectral_norm)
|
self.disc_latent = LatentDiscriminator(use_spectral_norm=use_spectral_norm, hidden_channels=hidden_channels)
|
||||||
|
|
||||||
def forward(self, x, x_hat=None, m_p=None, z_p=None):
|
def forward(self, x, x_hat=None, m_p=None, z_p=None):
|
||||||
"""
|
"""
|
||||||
|
@ -104,12 +104,13 @@ class VitsDiscriminator(nn.Module):
|
||||||
class LatentDiscriminator(nn.Module):
|
class LatentDiscriminator(nn.Module):
|
||||||
"""Discriminator with the same architecture as the Univnet SpecDiscriminator"""
|
"""Discriminator with the same architecture as the Univnet SpecDiscriminator"""
|
||||||
|
|
||||||
def __init__(self, use_spectral_norm=False):
|
def __init__(self, use_spectral_norm=False, hidden_channels=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
|
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
self.discriminators = nn.ModuleList(
|
self.discriminators = nn.ModuleList(
|
||||||
[
|
[
|
||||||
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
|
norm_f(nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))),
|
||||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
|
@ -121,6 +122,8 @@ class LatentDiscriminator(nn.Module):
|
||||||
|
|
||||||
def forward(self, y):
|
def forward(self, y):
|
||||||
fmap = []
|
fmap = []
|
||||||
|
if self.hidden_channels is not None:
|
||||||
|
y = y.squeeze(1).unsqueeze(-1)
|
||||||
for _, d in enumerate(self.discriminators):
|
for _, d in enumerate(self.discriminators):
|
||||||
y = d(y)
|
y = d(y)
|
||||||
y = torch.nn.functional.leaky_relu(y, 0.1)
|
y = torch.nn.functional.leaky_relu(y, 0.1)
|
||||||
|
|
|
@ -22,6 +22,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||||
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
|
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
|
||||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator, LatentDiscriminator
|
from TTS.tts.layers.vits.discriminator import VitsDiscriminator, LatentDiscriminator
|
||||||
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
||||||
|
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
|
||||||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||||
|
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
|
@ -556,6 +557,7 @@ class VitsArgs(Coqpit):
|
||||||
use_prosody_enc_spk_reversal_classifier: bool = False
|
use_prosody_enc_spk_reversal_classifier: bool = False
|
||||||
use_prosody_enc_emo_classifier: bool = False
|
use_prosody_enc_emo_classifier: bool = False
|
||||||
|
|
||||||
|
use_noise_scale_predictor: bool = False
|
||||||
|
|
||||||
use_prosody_conditional_flow_module: bool = False
|
use_prosody_conditional_flow_module: bool = False
|
||||||
prosody_conditional_flow_module_on_decoder: bool = False
|
prosody_conditional_flow_module_on_decoder: bool = False
|
||||||
|
@ -565,6 +567,7 @@ class VitsArgs(Coqpit):
|
||||||
use_soft_dtw: bool = False
|
use_soft_dtw: bool = False
|
||||||
|
|
||||||
use_latent_discriminator: bool = False
|
use_latent_discriminator: bool = False
|
||||||
|
provide_hidden_dim_on_the_latent_discriminator: bool = False
|
||||||
|
|
||||||
detach_dp_input: bool = True
|
detach_dp_input: bool = True
|
||||||
use_language_embedding: bool = False
|
use_language_embedding: bool = False
|
||||||
|
@ -650,8 +653,8 @@ class Vits(BaseTTS):
|
||||||
self.args.kernel_size_text_encoder,
|
self.args.kernel_size_text_encoder,
|
||||||
self.args.dropout_p_text_encoder,
|
self.args.dropout_p_text_encoder,
|
||||||
language_emb_dim=self.embedded_language_dim,
|
language_emb_dim=self.embedded_language_dim,
|
||||||
emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_prosody_conditional_flow_module else 0,
|
emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else 0,
|
||||||
prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_prosody_conditional_flow_module else 0,
|
prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.posterior_encoder = PosteriorEncoder(
|
self.posterior_encoder = PosteriorEncoder(
|
||||||
|
@ -682,10 +685,10 @@ class Vits(BaseTTS):
|
||||||
dp_cond_embedding_dim += self.args.prosody_embedding_dim
|
dp_cond_embedding_dim += self.args.prosody_embedding_dim
|
||||||
|
|
||||||
dp_extra_inp_dim = 0
|
dp_extra_inp_dim = 0
|
||||||
if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.use_prosody_conditional_flow_module:
|
if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor:
|
||||||
dp_extra_inp_dim += self.args.emotion_embedding_dim
|
dp_extra_inp_dim += self.args.emotion_embedding_dim
|
||||||
|
|
||||||
if self.args.use_prosody_encoder and not self.args.use_prosody_conditional_flow_module:
|
if self.args.use_prosody_encoder and not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor:
|
||||||
dp_extra_inp_dim += self.args.prosody_embedding_dim
|
dp_extra_inp_dim += self.args.prosody_embedding_dim
|
||||||
|
|
||||||
if self.args.use_sdp:
|
if self.args.use_sdp:
|
||||||
|
@ -756,6 +759,27 @@ class Vits(BaseTTS):
|
||||||
cond_channels=cond_embedding_dim,
|
cond_channels=cond_embedding_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.args.use_noise_scale_predictor:
|
||||||
|
noise_scale_predictor_input_dim = self.args.hidden_channels
|
||||||
|
if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
|
||||||
|
noise_scale_predictor_input_dim += self.args.emotion_embedding_dim
|
||||||
|
|
||||||
|
if self.args.use_prosody_encoder:
|
||||||
|
noise_scale_predictor_input_dim += self.args.prosody_embedding_dim
|
||||||
|
|
||||||
|
self.noise_scale_predictor = RelativePositionTransformer(
|
||||||
|
in_channels=noise_scale_predictor_input_dim,
|
||||||
|
out_channels=self.args.hidden_channels,
|
||||||
|
hidden_channels=noise_scale_predictor_input_dim,
|
||||||
|
hidden_channels_ffn=self.args.hidden_channels_ffn_text_encoder,
|
||||||
|
num_heads=self.args.num_heads_text_encoder,
|
||||||
|
num_layers=4,
|
||||||
|
kernel_size=self.args.kernel_size_text_encoder,
|
||||||
|
dropout_p=self.args.dropout_p_text_encoder,
|
||||||
|
layer_norm_type="2",
|
||||||
|
rel_attn_window_size=4,
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.use_text_enc_spk_reversal_classifier:
|
if self.args.use_text_enc_spk_reversal_classifier:
|
||||||
self.speaker_text_enc_reversal_classifier = ReversalClassifier(
|
self.speaker_text_enc_reversal_classifier = ReversalClassifier(
|
||||||
in_channels=self.args.hidden_channels
|
in_channels=self.args.hidden_channels
|
||||||
|
@ -793,6 +817,7 @@ class Vits(BaseTTS):
|
||||||
periods=self.args.periods_multi_period_discriminator,
|
periods=self.args.periods_multi_period_discriminator,
|
||||||
use_spectral_norm=self.args.use_spectral_norm_disriminator,
|
use_spectral_norm=self.args.use_spectral_norm_disriminator,
|
||||||
use_latent_disc=self.args.use_latent_discriminator,
|
use_latent_disc=self.args.use_latent_discriminator,
|
||||||
|
hidden_channels=self.args.hidden_channels if self.args.provide_hidden_dim_on_the_latent_discriminator else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_multispeaker(self, config: Coqpit):
|
def init_multispeaker(self, config: Coqpit):
|
||||||
|
@ -1203,8 +1228,8 @@ class Vits(BaseTTS):
|
||||||
x,
|
x,
|
||||||
x_lengths,
|
x_lengths,
|
||||||
lang_emb=lang_emb,
|
lang_emb=lang_emb,
|
||||||
emo_emb=eg if not self.args.use_prosody_conditional_flow_module else None,
|
emo_emb=eg if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else None,
|
||||||
pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module else None
|
pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# reversal speaker loss to force the encoder to be speaker identity free
|
# reversal speaker loss to force the encoder to be speaker identity free
|
||||||
|
@ -1249,6 +1274,17 @@ class Vits(BaseTTS):
|
||||||
# expand prior
|
# expand prior
|
||||||
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||||
logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
||||||
|
if self.args.use_noise_scale_predictor:
|
||||||
|
nsp_input = torch.transpose(m_p_expanded, 1, -1)
|
||||||
|
if self.args.use_prosody_encoder and pros_emb is not None:
|
||||||
|
nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
|
||||||
|
|
||||||
|
if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
|
||||||
|
nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
|
||||||
|
|
||||||
|
nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
|
||||||
|
m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
|
||||||
|
m_p_expanded = m_p_expanded + m_p_noise_scale * torch.exp(logs_p_expanded)
|
||||||
|
|
||||||
# select a random feature segment for the waveform decoder
|
# select a random feature segment for the waveform decoder
|
||||||
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
|
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
|
||||||
|
@ -1424,8 +1460,8 @@ class Vits(BaseTTS):
|
||||||
x,
|
x,
|
||||||
x_lengths,
|
x_lengths,
|
||||||
lang_emb=lang_emb,
|
lang_emb=lang_emb,
|
||||||
emo_emb=eg if not self.args.use_prosody_conditional_flow_module else None,
|
emo_emb=eg if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else None,
|
||||||
pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module else None
|
pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# conditional module
|
# conditional module
|
||||||
|
@ -1470,7 +1506,19 @@ class Vits(BaseTTS):
|
||||||
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
|
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
|
||||||
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
|
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
|
if self.args.use_noise_scale_predictor:
|
||||||
|
nsp_input = torch.transpose(m_p, 1, -1)
|
||||||
|
if self.args.use_prosody_encoder and pros_emb is not None:
|
||||||
|
nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
|
||||||
|
|
||||||
|
if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
|
||||||
|
nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
|
||||||
|
|
||||||
|
nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
|
||||||
|
m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
|
||||||
|
z_p = m_p + m_p_noise_scale * torch.exp(logs_p)
|
||||||
|
else:
|
||||||
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
|
||||||
|
|
||||||
# conditional module
|
# conditional module
|
||||||
if self.args.use_prosody_conditional_flow_module:
|
if self.args.use_prosody_conditional_flow_module:
|
||||||
|
|
|
@ -49,10 +49,13 @@ config.model_args.use_prosody_enc_emo_classifier = False
|
||||||
config.model_args.use_text_enc_emo_classifier = False
|
config.model_args.use_text_enc_emo_classifier = False
|
||||||
config.model_args.use_prosody_encoder_z_p_input = True
|
config.model_args.use_prosody_encoder_z_p_input = True
|
||||||
|
|
||||||
config.model_args.prosody_encoder_type = "vae"
|
config.model_args.prosody_encoder_type = "gst"
|
||||||
config.model_args.detach_prosody_enc_input = True
|
config.model_args.detach_prosody_enc_input = True
|
||||||
|
|
||||||
config.model_args.use_latent_discriminator = False
|
config.model_args.use_latent_discriminator = True
|
||||||
|
config.model_args.provide_hidden_dim_on_the_latent_discriminator = True
|
||||||
|
config.model_args.use_noise_scale_predictor = False
|
||||||
|
|
||||||
# enable end2end loss
|
# enable end2end loss
|
||||||
config.model_args.use_end2end_loss = False
|
config.model_args.use_end2end_loss = False
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue