mirror of https://github.com/coqui-ai/TTS.git
Add argument to enable dp speaker conditioning
This commit is contained in:
parent
e1accb6e28
commit
e778bad626
|
@ -171,6 +171,9 @@ class VitsArgs(Coqpit):
|
|||
speaker_encoder_model_path (str):
|
||||
Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
|
||||
|
||||
condition_dp_on_speaker (bool):
|
||||
Condition the duration predictor on the speaker embedding. Defaults to True.
|
||||
|
||||
freeze_encoder (bool):
|
||||
Freeze the encoder weigths during training. Defaults to False.
|
||||
|
||||
|
@ -233,6 +236,7 @@ class VitsArgs(Coqpit):
|
|||
use_speaker_encoder_as_loss: bool = False
|
||||
speaker_encoder_config_path: str = ""
|
||||
speaker_encoder_model_path: str = ""
|
||||
condition_dp_on_speaker: bool = True
|
||||
freeze_encoder: bool = False
|
||||
freeze_DP: bool = False
|
||||
freeze_PE: bool = False
|
||||
|
@ -349,7 +353,7 @@ class Vits(BaseTTS):
|
|||
3,
|
||||
args.dropout_p_duration_predictor,
|
||||
4,
|
||||
cond_channels=self.embedded_speaker_dim,
|
||||
cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
|
||||
language_emb_dim=self.embedded_language_dim,
|
||||
)
|
||||
else:
|
||||
|
@ -358,7 +362,7 @@ class Vits(BaseTTS):
|
|||
256,
|
||||
3,
|
||||
args.dropout_p_duration_predictor,
|
||||
cond_channels=self.embedded_speaker_dim,
|
||||
cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
|
||||
language_emb_dim=self.embedded_language_dim,
|
||||
)
|
||||
|
||||
|
@ -595,12 +599,15 @@ class Vits(BaseTTS):
|
|||
|
||||
# duration predictor
|
||||
attn_durations = attn.sum(3)
|
||||
g_dp = None
|
||||
if self.args.condition_dp_on_speaker:
|
||||
g_dp = g.detach() if self.args.detach_dp_input and g is not None else g
|
||||
if self.args.use_sdp:
|
||||
loss_duration = self.duration_predictor(
|
||||
x.detach() if self.args.detach_dp_input else x,
|
||||
x_mask,
|
||||
attn_durations,
|
||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
g=g_dp,
|
||||
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
||||
)
|
||||
loss_duration = loss_duration / torch.sum(x_mask)
|
||||
|
@ -609,7 +616,7 @@ class Vits(BaseTTS):
|
|||
log_durations = self.duration_predictor(
|
||||
x.detach() if self.args.detach_dp_input else x,
|
||||
x_mask,
|
||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
g=g_dp,
|
||||
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
||||
)
|
||||
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
|
||||
|
@ -685,10 +692,10 @@ class Vits(BaseTTS):
|
|||
|
||||
if self.args.use_sdp:
|
||||
logw = self.duration_predictor(
|
||||
x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
|
||||
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
|
||||
)
|
||||
else:
|
||||
logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
|
||||
logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb)
|
||||
|
||||
w = torch.exp(logw) * x_mask * self.length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
|
|
Loading…
Reference in New Issue