mirror of https://github.com/coqui-ai/TTS.git
Add conditional module
parent bce4a41b9c
commit a6c8fea192
@@ -549,6 +549,8 @@ class VitsArgs(Coqpit):
     prosody_encoder_num_tokens: int = 5
     use_prosody_enc_spk_reversal_classifier: bool = False
 
+    use_prosody_conditional_flow_module: bool = False
+
     detach_dp_input: bool = True
     use_language_embedding: bool = False
     embedded_language_dim: int = 4
@@ -633,8 +635,8 @@ class Vits(BaseTTS):
             self.args.kernel_size_text_encoder,
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
-            emotion_emb_dim=self.args.emotion_embedding_dim,
-            prosody_emb_dim=self.args.prosody_embedding_dim,
+            emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_prosody_conditional_flow_module else 0,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_prosody_conditional_flow_module else 0,
         )
 
         self.posterior_encoder = PosteriorEncoder(
@@ -664,9 +666,16 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder:
             dp_cond_embedding_dim += self.args.prosody_embedding_dim
 
+        dp_extra_inp_dim = 0
+        if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.use_prosody_conditional_flow_module:
+            dp_extra_inp_dim += self.args.emotion_embedding_dim
+
+        if self.args.use_prosody_encoder and not self.args.use_prosody_conditional_flow_module:
+            dp_extra_inp_dim += self.args.prosody_embedding_dim
+
         if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
-                self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
+                self.args.hidden_channels + dp_extra_inp_dim,
                 192,
                 3,
                 self.args.dropout_p_duration_predictor,
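
To make the bookkeeping above concrete (not part of the diff), here is a small standalone illustration of dp_extra_inp_dim. The flag names mirror the diff; the dimension values are placeholders, not values taken from the repo.

from types import SimpleNamespace

args = SimpleNamespace(
    hidden_channels=192,                       # placeholder value
    emotion_embedding_dim=256,                 # placeholder value
    prosody_embedding_dim=64,                  # placeholder value
    use_emotion_embedding=True,
    use_external_emotions_embeddings=False,
    use_prosody_encoder=True,
    use_prosody_conditional_flow_module=True,  # the new flag
)

dp_extra_inp_dim = 0
if (args.use_emotion_embedding or args.use_external_emotions_embeddings) and not args.use_prosody_conditional_flow_module:
    dp_extra_inp_dim += args.emotion_embedding_dim
if args.use_prosody_encoder and not args.use_prosody_conditional_flow_module:
    dp_extra_inp_dim += args.prosody_embedding_dim

# with the conditional flow module enabled, nothing is appended: the duration
# predictor input stays at hidden_channels and the conditioning moves to the flow
print(args.hidden_channels + dp_extra_inp_dim)  # 192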
@@ -676,7 +685,7 @@ class Vits(BaseTTS):
             )
         else:
             self.duration_predictor = DurationPredictor(
-                self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
+                self.args.hidden_channels + dp_extra_inp_dim,
                 256,
                 3,
                 self.args.dropout_p_duration_predictor,
@@ -698,11 +707,27 @@ class Vits(BaseTTS):
                 hidden_channels=256,
             )
 
+        if self.args.use_prosody_conditional_flow_module:
+            cond_embedding_dim = 0
+            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+                cond_embedding_dim += self.args.emotion_embedding_dim
+
+            if self.args.use_prosody_encoder:
+                cond_embedding_dim += self.args.prosody_embedding_dim
+
+            self.prosody_conditional_module = ResidualCouplingBlocks(
+                self.args.hidden_channels,
+                self.args.hidden_channels,
+                kernel_size=self.args.kernel_size_flow,
+                dilation_rate=self.args.dilation_rate_flow,
+                num_layers=2,
+                cond_channels=cond_embedding_dim,
+            )
+
         if self.args.use_text_enc_spk_reversal_classifier:
             self.speaker_text_enc_reversal_classifier = ReversalClassifier(
                 in_channels=self.args.hidden_channels
-                + self.args.emotion_embedding_dim
-                + self.args.prosody_embedding_dim,
+                + dp_extra_inp_dim,
                 out_channels=self.num_speakers,
                 hidden_channels=256,
             )
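
The conditional module above reuses the repo's ResidualCouplingBlocks flow with cond_channels set to the combined emotion/prosody dimension. As a rough, self-contained sketch of the idea only (not the repo's implementation), a single affine coupling conditioned on an utterance-level embedding could look like this; all shapes and names below are illustrative assumptions.

import torch
import torch.nn as nn

class ToyConditionalCoupling(nn.Module):
    """Minimal affine coupling conditioned on an utterance-level embedding.

    Illustrative sketch only; the module instantiated in the diff is the
    repo's ResidualCouplingBlocks with cond_channels=cond_embedding_dim.
    """

    def __init__(self, channels: int, cond_channels: int):
        super().__init__()
        assert channels % 2 == 0
        self.half = channels // 2
        # predicts log-scale and shift for the second half from the first half plus the condition
        self.net = nn.Conv1d(self.half + cond_channels, self.half * 2, kernel_size=1)

    def forward(self, x, x_mask, g):
        # x: [B, channels, T], x_mask: [B, 1, T], g: [B, cond_channels, 1]
        xa, xb = x.split(self.half, dim=1)
        h = torch.cat([xa, g.expand(-1, -1, xa.size(-1))], dim=1)
        log_s, t = self.net(h).split(self.half, dim=1)
        xb = (xb * torch.exp(log_s) + t) * x_mask
        return torch.cat([xa, xb], dim=1) * x_mask

# toy usage: transform a text-encoder prior mean with an emotion embedding
m_p = torch.randn(2, 192, 50)   # [B, hidden_channels, T_text], placeholder sizes
x_mask = torch.ones(2, 1, 50)
eg = torch.randn(2, 256, 1)     # emotion embedding used as conditioning
coupling = ToyConditionalCoupling(192, 256)
m_p = coupling(m_p, x_mask, g=eg)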
@@ -1097,7 +1122,17 @@ class Vits(BaseTTS):
             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
 
-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
+        x, m_p, logs_p, x_mask = self.text_encoder(
+            x,
+            x_lengths,
+            lang_emb=lang_emb,
+            emo_emb=eg if not self.args.use_prosody_conditional_flow_module else None,
+            pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module else None
+        )
+
+        # conditional module
+        if self.args.use_prosody_conditional_flow_module:
+            m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
 
         # reversal speaker loss to force the encoder to be speaker identity free
         l_text_speaker = None
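
One detail worth spelling out (not part of the diff): when the conditional module is active, the emotion and prosody embeddings are withheld from the text encoder and instead condition the flow, with the emotion embedding taking priority over the prosody-encoder output. A standalone restatement of that selection, with assumed shapes:

import torch

use_emotion_embedding = True
use_external_emotions_embeddings = False

eg = torch.randn(2, 256, 1)        # emotion embedding, [B, emotion_embedding_dim, 1] (assumed shape)
pros_emb = torch.randn(2, 64, 1)   # prosody-encoder output, [B, prosody_embedding_dim, 1] (assumed shape)

# same priority rule as the diff: emotion conditioning wins when enabled,
# otherwise fall back to the prosody embedding
g_cond = eg if (use_emotion_embedding or use_external_emotions_embeddings) else pros_emb
print(g_cond.shape)  # torch.Size([2, 256, 1])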
@@ -1246,7 +1281,17 @@ class Vits(BaseTTS):
             z_pro, _, _, _ = self.posterior_encoder(pf, pf_lengths, g=g)
             pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
 
-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
+        x, m_p, logs_p, x_mask = self.text_encoder(
+            x,
+            x_lengths,
+            lang_emb=lang_emb,
+            emo_emb=eg if not self.args.use_prosody_conditional_flow_module else None,
+            pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module else None
+        )
+
+        # conditional module
+        if self.args.use_prosody_conditional_flow_module:
+            m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
 
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None

@@ -47,6 +47,8 @@ config.model_args.emotion_embedding_dim = 256
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
 config.model_args.use_text_enc_spk_reversal_classifier = False
 
+
+config.model_args.use_prosody_conditional_flow_module = True
 # consistency loss
 # config.model_args.use_emotion_encoder_as_loss = True
 # config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"