Add support to use the speaker encoder as a loss function in the VITS model

Edresson authored 2021-09-03 07:37:43 -03:00, committed by Eren Gölge
parent 9b011b1cb3
commit 690b37d0ab
3 changed files with 56 additions and 12 deletions

TTS/tts/configs/vits_config.py

@@ -117,6 +117,7 @@ class VitsConfig(BaseTTSConfig):
     feat_loss_alpha: float = 1.0
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
+    speaker_encoder_loss_alpha: float = 1.0
     # data loader params
     return_wav: bool = True
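A minimal sketch of how the new weight might be set from user code (the import path follows the usual Coqui TTS layout and the value is illustrative, not from this commit):

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.speaker_encoder_loss_alpha = 9.0  # scales the new speaker-consistency term; 1.0 is the default above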

TTS/tts/layers/losses.py

@@ -532,6 +532,7 @@ class VitsGeneratorLoss(nn.Module):
         self.feat_loss_alpha = c.feat_loss_alpha
         self.dur_loss_alpha = c.dur_loss_alpha
         self.mel_loss_alpha = c.mel_loss_alpha
+        self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha
         self.stft = TorchSTFT(
             c.audio.fft_size,
             c.audio.hop_length,
@@ -599,6 +600,9 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_real,
         loss_duration,
         fine_tuning_mode=False,
+        use_speaker_encoder_as_loss=False,
+        gt_spk_emb=None,
+        syn_spk_emb=None,
     ):
         """
         Shapes:
@@ -632,6 +636,12 @@ class VitsGeneratorLoss(nn.Module):
         loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
         loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
+
+        if use_speaker_encoder_as_loss:
+            loss_se = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            loss += loss_se
+            return_dict["loss_spk_encoder"] = loss_se
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl

TTS/tts/models/vits.py

@@ -195,6 +195,10 @@ class VitsArgs(Coqpit):
     embedded_language_dim: int = 4
     num_languages: int = 0
     fine_tuning_mode: bool = False
+
+    use_speaker_encoder_as_loss: bool = False
+    speaker_encoder_config_path: str = ""
+    speaker_encoder_model_path: str = ""


 class Vits(BaseTTS):
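Enabling the feature therefore takes three new VitsArgs fields. A hedged configuration sketch (the model_args attribute name follows the usual VitsConfig layout, and both paths are placeholders, not values from the diff):

config = VitsConfig()
config.model_args.use_speaker_encoder_as_loss = True
config.model_args.speaker_encoder_model_path = "/path/to/speaker_encoder.pth"  # placeholder
config.model_args.speaker_encoder_config_path = "/path/to/speaker_encoder_config.json"  # placeholder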
@@ -370,6 +374,18 @@ class Vits(BaseTTS):
             self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
             self.embedded_speaker_dim = config.d_vector_dim
+
+        if config.use_speaker_encoder_as_loss:
+            if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
+                raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
+            self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
+            self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
+            for param in self.speaker_encoder.parameters():
+                param.requires_grad = False
+            print(" > External Speaker Encoder Loaded !!")
+        else:
+            self.speaker_encoder = None

     def init_multilingual(self, config: Coqpit, data: List = None):
         """Initialize multilingual modules of a model.
@@ -427,6 +443,7 @@ class Vits(BaseTTS):
         y: torch.tensor,
         y_lengths: torch.tensor,
         aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
+        waveform=None,
     ) -> Dict:
         """Forward pass of the model.
@@ -461,7 +478,6 @@ class Vits(BaseTTS):
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@@ -508,17 +524,36 @@ class Vits(BaseTTS):
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
         o = self.waveform_decoder(z_slice, g=g)
+
+        wav_seg = segment(
+            waveform.transpose(1, 2),
+            slice_ids * self.config.audio.hop_length,
+            self.args.spec_segment_size * self.config.audio.hop_length,
+        )
+
+        if self.args.use_speaker_encoder_as_loss:
+            # concatenate ground-truth and generated waveforms
+            wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
+            pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
+
+            # split the embeddings back into ground-truth and generated halves
+            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
+        else:
+            gt_spk_emb, syn_spk_emb = None, None
+
         outputs.update(
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
                 "slice_ids": slice_ids,
                 "z": z,
                 "z_p": z_p,
                 "m_p": m_p,
                 "logs_p": logs_p,
                 "m_q": m_q,
                 "logs_q": logs_q,
+                "waveform_seg": wav_seg,
+                "gt_spk_emb": gt_spk_emb,
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
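Two things happen in the new forward code. First, slice_ids index spectrogram frames, so multiplying by audio.hop_length converts them to waveform sample offsets; this keeps wav_seg time-aligned with the z_slice fed to the waveform decoder. Illustrative arithmetic (values assumed, not from the commit):

hop_length = 256        # samples per spectrogram frame (typical value, assumed)
spec_segment_size = 32  # frames passed to the waveform decoder
slice_id = 100          # start frame picked by rand_segments
start_sample = slice_id * hop_length          # 25600
num_samples = spec_segment_size * hop_length  # 8192

Second, stacking wav_seg before o in torch.cat means the first half of pred_embs belongs to the ground truth and the second half to the synthesized audio, which is exactly the order torch.chunk unpacks into gt_spk_emb and syn_spk_emb.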
@@ -596,7 +631,6 @@ class Vits(BaseTTS):
         {
             "model_outputs": o,
             "alignments": attn.squeeze(1),
-            "slice_ids": slice_ids,
             "z": z,
             "z_p": z_p,
             "m_p": m_p,
@@ -713,6 +747,7 @@ class Vits(BaseTTS):
                     linear_input.transpose(1, 2),
                     mel_lengths,
                     aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
+                    waveform=waveform,
                 )
             else:
                 outputs = self.forward(
@@ -721,30 +756,25 @@ class Vits(BaseTTS):
                     linear_input.transpose(1, 2),
                     mel_lengths,
                     aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
+                    waveform=waveform,
                 )

             # cache tensors for the discriminator
             self.y_disc_cache = None
             self.wav_seg_disc_cache = None
             self.y_disc_cache = outputs["model_outputs"]
-            wav_seg = segment(
-                waveform.transpose(1, 2),
-                outputs["slice_ids"] * self.config.audio.hop_length,
-                self.args.spec_segment_size * self.config.audio.hop_length,
-            )
-            self.wav_seg_disc_cache = wav_seg
-            outputs["waveform_seg"] = wav_seg
+            self.wav_seg_disc_cache = outputs["waveform_seg"]

             # compute discriminator scores and features
             outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
-                outputs["model_outputs"], wav_seg
+                outputs["model_outputs"], outputs["waveform_seg"]
             )

             # compute losses
             with autocast(enabled=False):  # use float32 for the criterion
                 loss_dict = criterion[optimizer_idx](
                     waveform_hat=outputs["model_outputs"].float(),
-                    waveform=wav_seg.float(),
+                    waveform=outputs["waveform_seg"].float(),
                     z_p=outputs["z_p"].float(),
                     logs_q=outputs["logs_q"].float(),
                     m_p=outputs["m_p"].float(),
@@ -755,6 +785,9 @@ class Vits(BaseTTS):
                     feats_disc_real=outputs["feats_disc_real"],
                     loss_duration=outputs["loss_duration"],
                     fine_tuning_mode=self.args.fine_tuning_mode,
+                    use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
+                    gt_spk_emb=outputs["gt_spk_emb"],
+                    syn_spk_emb=outputs["syn_spk_emb"],
                 )
             # ignore duration loss if fine tuning mode is on
             if not self.args.fine_tuning_mode:
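Putting the pieces together, a toy training step under the same scheme (stand-in modules, not the real VITS model or speaker encoder) shows the full gradient path this commit wires up: frozen encoder, one batched embedding pass, negative-cosine loss, and an optimizer step that only moves the generator:

import torch
import torch.nn as nn
import torch.nn.functional as F

generator = nn.Linear(32, 256)    # stands in for the waveform decoder
spk_encoder = nn.Linear(256, 64)  # stands in for the frozen speaker encoder
for p in spk_encoder.parameters():
    p.requires_grad = False

opt = torch.optim.Adam(generator.parameters(), lr=1e-3)
z = torch.randn(4, 32)        # stands in for the latent slice
wav_gt = torch.randn(4, 256)  # stands in for wav_seg

wav_hat = generator(z)
embs = spk_encoder(torch.cat((wav_gt, wav_hat), dim=0))  # one batched pass
gt_emb, syn_emb = torch.chunk(embs, 2, dim=0)
loss_se = -F.cosine_similarity(gt_emb, syn_emb).mean()
loss_se.backward()
opt.step()  # only the generator's weights are updated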