Merge pull request #1069 from WeberJulian/condition_dp_speaker

Add argument to enable dp speaker conditioning in VITS
This commit is contained in:
Eren Gölge 2022-02-06 20:12:35 +01:00 committed by GitHub
commit 2ca2c8c431
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 13 additions and 6 deletions

View File

@@ -171,6 +171,9 @@ class VitsArgs(Coqpit):
         speaker_encoder_model_path (str):
             Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".

+        condition_dp_on_speaker (bool):
+            Condition the duration predictor on the speaker embedding. Defaults to True.
+
         freeze_encoder (bool):
             Freeze the encoder weigths during training. Defaults to False.
@@ -233,6 +236,7 @@ class VitsArgs(Coqpit):
     use_speaker_encoder_as_loss: bool = False
     speaker_encoder_config_path: str = ""
     speaker_encoder_model_path: str = ""
+    condition_dp_on_speaker: bool = True
     freeze_encoder: bool = False
     freeze_DP: bool = False
     freeze_PE: bool = False
@ -349,7 +353,7 @@ class Vits(BaseTTS):
3, 3,
args.dropout_p_duration_predictor, args.dropout_p_duration_predictor,
4, 4,
cond_channels=self.embedded_speaker_dim, cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
language_emb_dim=self.embedded_language_dim, language_emb_dim=self.embedded_language_dim,
) )
else: else:
@@ -358,7 +362,7 @@ class Vits(BaseTTS):
                 256,
                 3,
                 args.dropout_p_duration_predictor,
-                cond_channels=self.embedded_speaker_dim,
+                cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
                 language_emb_dim=self.embedded_language_dim,
             )
@@ -595,12 +599,15 @@ class Vits(BaseTTS):
         # duration predictor
         attn_durations = attn.sum(3)
+        g_dp = None
+        if self.args.condition_dp_on_speaker:
+            g_dp = g.detach() if self.args.detach_dp_input and g is not None else g
         if self.args.use_sdp:
             loss_duration = self.duration_predictor(
                 x.detach() if self.args.detach_dp_input else x,
                 x_mask,
                 attn_durations,
-                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                g=g_dp,
                 lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = loss_duration / torch.sum(x_mask)
@@ -609,7 +616,7 @@ class Vits(BaseTTS):
             log_durations = self.duration_predictor(
                 x.detach() if self.args.detach_dp_input else x,
                 x_mask,
-                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                g=g_dp,
                 lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
@@ -685,10 +692,10 @@ class Vits(BaseTTS):
         if self.args.use_sdp:
             logw = self.duration_predictor(
-                x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
+                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
             )
         else:
-            logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
+            logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb)
         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)