Add argument to enable dp speaker conditioning

This commit is contained in:
WeberJulian 2022-01-06 15:07:27 +01:00
parent e1accb6e28
commit e778bad626
1 changed files with 13 additions and 6 deletions

View File

@@ -171,6 +171,9 @@ class VitsArgs(Coqpit):
speaker_encoder_model_path (str):
Path to the speaker encoder checkpoint file, to use for SCL. Defaults to "".
condition_dp_on_speaker (bool):
Condition the duration predictor on the speaker embedding. Defaults to True.
freeze_encoder (bool):
Freeze the encoder weights during training. Defaults to False.
@@ -233,6 +236,7 @@ class VitsArgs(Coqpit):
use_speaker_encoder_as_loss: bool = False
speaker_encoder_config_path: str = ""
speaker_encoder_model_path: str = ""
condition_dp_on_speaker: bool = True
freeze_encoder: bool = False
freeze_DP: bool = False
freeze_PE: bool = False
@@ -349,7 +353,7 @@ class Vits(BaseTTS):
3,
args.dropout_p_duration_predictor,
4,
cond_channels=self.embedded_speaker_dim,
cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
language_emb_dim=self.embedded_language_dim,
)
else:
@@ -358,7 +362,7 @@ class Vits(BaseTTS):
256,
3,
args.dropout_p_duration_predictor,
cond_channels=self.embedded_speaker_dim,
cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
language_emb_dim=self.embedded_language_dim,
)
@@ -595,12 +599,15 @@ class Vits(BaseTTS):
# duration predictor
attn_durations = attn.sum(3)
g_dp = None
if self.args.condition_dp_on_speaker:
g_dp = g.detach() if self.args.detach_dp_input and g is not None else g
if self.args.use_sdp:
loss_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
g=g_dp,
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
)
loss_duration = loss_duration / torch.sum(x_mask)
@@ -609,7 +616,7 @@ class Vits(BaseTTS):
log_durations = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
g=g_dp,
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
)
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
@@ -685,10 +692,10 @@ class Vits(BaseTTS):
if self.args.use_sdp:
logw = self.duration_predictor(
x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
)
else:
logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)