mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #1069 from WeberJulian/condition_dp_speaker
Add argument to enable dp speaker conditioning in VITS
This commit is contained in:
commit
2ca2c8c431
|
@ -171,6 +171,9 @@ class VitsArgs(Coqpit):
|
||||||
speaker_encoder_model_path (str):
|
speaker_encoder_model_path (str):
|
||||||
Path to the speaker encoder checkpoint file, to use for SCL. Defaults to "".
|
Path to the speaker encoder checkpoint file, to use for SCL. Defaults to "".
|
||||||
|
|
||||||
|
condition_dp_on_speaker (bool):
|
||||||
|
Condition the duration predictor on the speaker embedding. Defaults to True.
|
||||||
|
|
||||||
freeze_encoder (bool):
|
freeze_encoder (bool):
|
||||||
Freeze the encoder weights during training. Defaults to False.
|
Freeze the encoder weights during training. Defaults to False.
|
||||||
|
|
||||||
|
@ -233,6 +236,7 @@ class VitsArgs(Coqpit):
|
||||||
use_speaker_encoder_as_loss: bool = False
|
use_speaker_encoder_as_loss: bool = False
|
||||||
speaker_encoder_config_path: str = ""
|
speaker_encoder_config_path: str = ""
|
||||||
speaker_encoder_model_path: str = ""
|
speaker_encoder_model_path: str = ""
|
||||||
|
condition_dp_on_speaker: bool = True
|
||||||
freeze_encoder: bool = False
|
freeze_encoder: bool = False
|
||||||
freeze_DP: bool = False
|
freeze_DP: bool = False
|
||||||
freeze_PE: bool = False
|
freeze_PE: bool = False
|
||||||
|
@ -349,7 +353,7 @@ class Vits(BaseTTS):
|
||||||
3,
|
3,
|
||||||
args.dropout_p_duration_predictor,
|
args.dropout_p_duration_predictor,
|
||||||
4,
|
4,
|
||||||
cond_channels=self.embedded_speaker_dim,
|
cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
|
||||||
language_emb_dim=self.embedded_language_dim,
|
language_emb_dim=self.embedded_language_dim,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -358,7 +362,7 @@ class Vits(BaseTTS):
|
||||||
256,
|
256,
|
||||||
3,
|
3,
|
||||||
args.dropout_p_duration_predictor,
|
args.dropout_p_duration_predictor,
|
||||||
cond_channels=self.embedded_speaker_dim,
|
cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
|
||||||
language_emb_dim=self.embedded_language_dim,
|
language_emb_dim=self.embedded_language_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -595,12 +599,15 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
# duration predictor
|
# duration predictor
|
||||||
attn_durations = attn.sum(3)
|
attn_durations = attn.sum(3)
|
||||||
|
g_dp = None
|
||||||
|
if self.args.condition_dp_on_speaker:
|
||||||
|
g_dp = g.detach() if self.args.detach_dp_input and g is not None else g
|
||||||
if self.args.use_sdp:
|
if self.args.use_sdp:
|
||||||
loss_duration = self.duration_predictor(
|
loss_duration = self.duration_predictor(
|
||||||
x.detach() if self.args.detach_dp_input else x,
|
x.detach() if self.args.detach_dp_input else x,
|
||||||
x_mask,
|
x_mask,
|
||||||
attn_durations,
|
attn_durations,
|
||||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
g=g_dp,
|
||||||
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
||||||
)
|
)
|
||||||
loss_duration = loss_duration / torch.sum(x_mask)
|
loss_duration = loss_duration / torch.sum(x_mask)
|
||||||
|
@ -609,7 +616,7 @@ class Vits(BaseTTS):
|
||||||
log_durations = self.duration_predictor(
|
log_durations = self.duration_predictor(
|
||||||
x.detach() if self.args.detach_dp_input else x,
|
x.detach() if self.args.detach_dp_input else x,
|
||||||
x_mask,
|
x_mask,
|
||||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
g=g_dp,
|
||||||
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
||||||
)
|
)
|
||||||
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
|
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
|
||||||
|
@ -685,10 +692,10 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
if self.args.use_sdp:
|
if self.args.use_sdp:
|
||||||
logw = self.duration_predictor(
|
logw = self.duration_predictor(
|
||||||
x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
|
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
|
logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb)
|
||||||
|
|
||||||
w = torch.exp(logw) * x_mask * self.length_scale
|
w = torch.exp(logw) * x_mask * self.length_scale
|
||||||
w_ceil = torch.ceil(w)
|
w_ceil = torch.ceil(w)
|
||||||
|
|
Loading…
Reference in New Issue