Add argument to enable dp speaker conditioning

This commit is contained in:
WeberJulian 2022-01-06 15:07:27 +01:00
parent e1accb6e28
commit e778bad626
1 changed file with 13 additions and 6 deletions

View File

@ -171,6 +171,9 @@ class VitsArgs(Coqpit):
speaker_encoder_model_path (str): speaker_encoder_model_path (str):
Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "". Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
condition_dp_on_speaker (bool):
Condition the duration predictor on the speaker embedding. Defaults to True.
freeze_encoder (bool): freeze_encoder (bool):
Freeze the encoder weights during training. Defaults to False. Freeze the encoder weights during training. Defaults to False.
@ -233,6 +236,7 @@ class VitsArgs(Coqpit):
use_speaker_encoder_as_loss: bool = False use_speaker_encoder_as_loss: bool = False
speaker_encoder_config_path: str = "" speaker_encoder_config_path: str = ""
speaker_encoder_model_path: str = "" speaker_encoder_model_path: str = ""
condition_dp_on_speaker: bool = True
freeze_encoder: bool = False freeze_encoder: bool = False
freeze_DP: bool = False freeze_DP: bool = False
freeze_PE: bool = False freeze_PE: bool = False
@ -349,7 +353,7 @@ class Vits(BaseTTS):
3, 3,
args.dropout_p_duration_predictor, args.dropout_p_duration_predictor,
4, 4,
cond_channels=self.embedded_speaker_dim, cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
language_emb_dim=self.embedded_language_dim, language_emb_dim=self.embedded_language_dim,
) )
else: else:
@ -358,7 +362,7 @@ class Vits(BaseTTS):
256, 256,
3, 3,
args.dropout_p_duration_predictor, args.dropout_p_duration_predictor,
cond_channels=self.embedded_speaker_dim, cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
language_emb_dim=self.embedded_language_dim, language_emb_dim=self.embedded_language_dim,
) )
@ -595,12 +599,15 @@ class Vits(BaseTTS):
# duration predictor # duration predictor
attn_durations = attn.sum(3) attn_durations = attn.sum(3)
g_dp = None
if self.args.condition_dp_on_speaker:
g_dp = g.detach() if self.args.detach_dp_input and g is not None else g
if self.args.use_sdp: if self.args.use_sdp:
loss_duration = self.duration_predictor( loss_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x, x.detach() if self.args.detach_dp_input else x,
x_mask, x_mask,
attn_durations, attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g, g=g_dp,
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
) )
loss_duration = loss_duration / torch.sum(x_mask) loss_duration = loss_duration / torch.sum(x_mask)
@ -609,7 +616,7 @@ class Vits(BaseTTS):
log_durations = self.duration_predictor( log_durations = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x, x.detach() if self.args.detach_dp_input else x,
x_mask, x_mask,
g=g.detach() if self.args.detach_dp_input and g is not None else g, g=g_dp,
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
) )
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
@ -685,10 +692,10 @@ class Vits(BaseTTS):
if self.args.use_sdp: if self.args.use_sdp:
logw = self.duration_predictor( logw = self.duration_predictor(
x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
) )
else: else:
logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb) logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb)
w = torch.exp(logw) * x_mask * self.length_scale w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w) w_ceil = torch.ceil(w)