mirror of https://github.com/coqui-ai/TTS.git
Add support to use the speaker encoder as a loss function in the VITS model
This commit is contained in:
parent a3901032f4
commit 3cd889a9d4

@@ -117,6 +117,7 @@ class VitsConfig(BaseTTSConfig):
     feat_loss_alpha: float = 1.0
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
+    speaker_encoder_loss_alpha: float = 1.0

     # data loader params
     return_wav: bool = True

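These `*_alpha` fields weight the individual terms of the generator objective, so `speaker_encoder_loss_alpha` controls how strongly the new speaker-similarity term pulls on training. A minimal sketch of overriding it when building a config (the import path is assumed from the repo layout and may differ; 9.0 is just an example value):

from TTS.tts.configs.vits_config import VitsConfig  # assumed import path

config = VitsConfig(
    mel_loss_alpha=45.0,  # defaults from the hunk above
    feat_loss_alpha=1.0,
    dur_loss_alpha=1.0,
    speaker_encoder_loss_alpha=9.0,  # weight speaker similarity more heavily
)
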
@@ -532,6 +532,7 @@ class VitsGeneratorLoss(nn.Module):
         self.feat_loss_alpha = c.feat_loss_alpha
         self.dur_loss_alpha = c.dur_loss_alpha
         self.mel_loss_alpha = c.mel_loss_alpha
+        self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha
         self.stft = TorchSTFT(
             c.audio.fft_size,
             c.audio.hop_length,

@@ -599,6 +600,9 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_real,
         loss_duration,
         fine_tuning_mode=False,
+        use_speaker_encoder_as_loss=False,
+        gt_spk_emb=None,
+        syn_spk_emb=None
     ):
         """
         Shapes:

@@ -632,6 +636,12 @@ class VitsGeneratorLoss(nn.Module):
         loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
         loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
+
+        if use_speaker_encoder_as_loss:
+            loss_se = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            loss += loss_se
+            return_dict["loss_spk_encoder"] = loss_se
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl

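This is the heart of the change: a speaker-consistency loss. The ground-truth and synthesized segments are both embedded by the external speaker encoder, and the negated cosine similarity between the two embeddings is added to the generator loss, so training pushes the synthesized voice toward the target speaker's identity. A self-contained sketch of the computation with dummy embeddings (in the model they come from the encoder with `l2_norm=True`):

import torch
import torch.nn.functional as F

alpha = 1.0  # speaker_encoder_loss_alpha

# stand-ins for L2-normalized speaker embeddings of shape (batch, emb_dim)
gt_spk_emb = F.normalize(torch.randn(8, 256), dim=1)
syn_spk_emb = F.normalize(torch.randn(8, 256), dim=1)

# per-pair cosine similarity lies in [-1, 1]; negating it turns
# "maximize similarity" into a loss to minimize, in [-alpha, alpha]
loss_se = -F.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * alpha
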
@@ -195,6 +195,10 @@ class VitsArgs(Coqpit):
     embedded_language_dim: int = 4
     num_languages: int = 0
     fine_tuning_mode: bool = False
+    use_speaker_encoder_as_loss: bool = False
+    speaker_encoder_config_path: str = ""
+    speaker_encoder_model_path: str = ""


 class Vits(BaseTTS):

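With the new `VitsArgs` fields, enabling the loss is a matter of pointing the model at a pretrained speaker encoder; the paths below are placeholders, not files shipped with the repo:

from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    use_speaker_encoder_as_loss=True,
    speaker_encoder_config_path="/path/to/se_config.json",  # placeholder
    speaker_encoder_model_path="/path/to/se_model.pth.tar",  # placeholder
)
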
@@ -370,6 +374,18 @@ class Vits(BaseTTS):
             self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
             self.embedded_speaker_dim = config.d_vector_dim

+        if config.use_speaker_encoder_as_loss:
+            if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
+                raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
+            self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
+            self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
+            for param in self.speaker_encoder.parameters():
+                param.requires_grad = False
+
+            print(" > External Speaker Encoder Loaded !!")
+        else:
+            self.speaker_encoder = None
+
     def init_multilingual(self, config: Coqpit, data: List = None):
         """Initialize multilingual modules of a model.

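Note the freezing pattern here: the encoder stays in `train()` mode, but every parameter has `requires_grad = False`, so gradients flow through its activations back to the generator while the encoder itself is never updated. A generic sketch of the same pattern (the toy `nn.Sequential` stands in for the speaker encoder):

import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
for param in encoder.parameters():
    param.requires_grad = False  # freeze the encoder weights

generated = torch.randn(2, 16, requires_grad=True)  # stand-in for generator output
reference = torch.randn(2, 16)                      # stand-in for GT audio
loss = -F.cosine_similarity(encoder(generated), encoder(reference)).mean()
loss.backward()  # gradients reach `generated`; encoder weights stay fixed
assert generated.grad is not None
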
@@ -427,6 +443,7 @@ class Vits(BaseTTS):
         y: torch.tensor,
         y_lengths: torch.tensor,
         aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
+        waveform=None,
     ) -> Dict:
         """Forward pass of the model.

@@ -461,7 +478,6 @@ class Vits(BaseTTS):
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

@@ -508,17 +524,36 @@ class Vits(BaseTTS):
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
         o = self.waveform_decoder(z_slice, g=g)

+        wav_seg = segment(
+            waveform.transpose(1, 2),
+            slice_ids * self.config.audio.hop_length,
+            self.args.spec_segment_size * self.config.audio.hop_length,
+        )
+
+        if self.args.use_speaker_encoder_as_loss:
+            # concatenate GT and generated waveforms
+            wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
+            pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
+
+            # split GT and generated speaker embeddings
+            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
+        else:
+            gt_spk_emb, syn_spk_emb = None, None
+
         outputs.update(
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
                 "slice_ids": slice_ids,
                 "z": z,
                 "z_p": z_p,
                 "m_p": m_p,
                 "logs_p": logs_p,
                 "m_q": m_q,
                 "logs_q": logs_q,
+                "waveform_seg": wav_seg,
+                "gt_spk_emb": gt_spk_emb,
+                "syn_spk_emb": syn_spk_emb
             }
         )
         return outputs

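Batching the ground-truth and generated segments into one tensor means a single encoder forward pass instead of two; `torch.chunk` then splits the embeddings back in the same order they were concatenated (GT first, then synthesized). A toy illustration, assuming both halves share the same batch size:

import torch

wav_seg = torch.randn(4, 1, 8192)  # GT segments, shape (B, 1, T)
o = torch.randn(4, 1, 8192)        # generated segments, same shape

wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)  # (2B, T), one encoder call
pred_embs = torch.nn.Identity()(wavs_batch)             # stand-in for the speaker encoder

gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)  # first half GT, second half synthesized
assert gt_spk_emb.shape[0] == syn_spk_emb.shape[0] == 4
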
@@ -596,7 +631,6 @@ class Vits(BaseTTS):
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
                 "slice_ids": slice_ids,
                 "z": z,
                 "z_p": z_p,
                 "m_p": m_p,

@@ -713,6 +747,7 @@ class Vits(BaseTTS):
                 linear_input.transpose(1, 2),
                 mel_lengths,
                 aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
+                waveform=waveform,
             )
         else:
             outputs = self.forward(

@@ -721,30 +756,25 @@ class Vits(BaseTTS):
                 linear_input.transpose(1, 2),
                 mel_lengths,
                 aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
+                waveform=waveform,
             )

         # cache tensors for the discriminator
         self.y_disc_cache = None
         self.wav_seg_disc_cache = None
         self.y_disc_cache = outputs["model_outputs"]
-        wav_seg = segment(
-            waveform.transpose(1, 2),
-            outputs["slice_ids"] * self.config.audio.hop_length,
-            self.args.spec_segment_size * self.config.audio.hop_length,
-        )
-        self.wav_seg_disc_cache = wav_seg
-        outputs["waveform_seg"] = wav_seg
+        self.wav_seg_disc_cache = outputs["waveform_seg"]

         # compute discriminator scores and features
         outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
-            outputs["model_outputs"], wav_seg
+            outputs["model_outputs"], outputs["waveform_seg"]
         )

         # compute losses
         with autocast(enabled=False):  # use float32 for the criterion
             loss_dict = criterion[optimizer_idx](
                 waveform_hat=outputs["model_outputs"].float(),
-                waveform=wav_seg.float(),
+                waveform=outputs["waveform_seg"].float(),
                 z_p=outputs["z_p"].float(),
                 logs_q=outputs["logs_q"].float(),
                 m_p=outputs["m_p"].float(),

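Because `forward` now slices the ground-truth waveform itself (it needs the segment for the speaker encoder), `train_step` reuses `outputs["waveform_seg"]` rather than calling `segment` again. The slice indices come from `rand_segments` in spectrogram frames, so multiplying by `hop_length` maps them onto the matching audio samples; a sketch of that frame-to-sample mapping with made-up sizes:

hop_length = 256        # audio samples per spectrogram frame
spec_segment_size = 32  # segment length in frames

slice_ids = [10, 42]  # per-item start frames picked by rand_segments
sample_starts = [i * hop_length for i in slice_ids]  # [2560, 10752]
segment_samples = spec_segment_size * hop_length     # 8192 samples per segment
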
@@ -755,6 +785,9 @@ class Vits(BaseTTS):
                 feats_disc_real=outputs["feats_disc_real"],
                 loss_duration=outputs["loss_duration"],
                 fine_tuning_mode=self.args.fine_tuning_mode,
+                use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
+                gt_spk_emb=outputs["gt_spk_emb"],
+                syn_spk_emb=outputs["syn_spk_emb"]
             )
             # ignore duration loss if fine tuning mode is on
             if not self.args.fine_tuning_mode: