Add emotion consistency loss

Edresson Casanova 2022-03-15 12:35:00 +00:00
parent cc3821332b
commit 5090034fd1
5 changed files with 88 additions and 52 deletions
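In short: like the existing speaker consistency loss (SCL), the emotion consistency loss (ECL) runs an external frozen encoder over both the ground-truth and the generated waveform segments and adds the negative mean cosine similarity between the two embeddings to the generator loss, weighted by `consistency_loss_alpha`. A minimal, hedged sketch of that term with dummy tensors (no real encoder involved; names are illustrative):

```python
import torch
import torch.nn.functional as F

def consistency_loss(gt_emb: torch.Tensor, syn_emb: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """Negative mean cosine similarity between ground-truth and synthesized embeddings."""
    return -F.cosine_similarity(gt_emb, syn_emb).mean() * alpha

# dummy 256-dim embeddings standing in for the emotion encoder output
gt_emb = torch.randn(4, 256)
syn_emb = torch.randn(4, 256)
print(consistency_loss(gt_emb, syn_emb))  # ~0 for random pairs, -alpha when the pairs are identical
```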

View File

@@ -114,7 +114,7 @@ class VitsConfig(BaseTTSConfig):
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
speaker_encoder_loss_alpha: float = 1.0
consistency_loss_alpha: float = 1.0
# data loader params
return_wav: bool = True
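For reference, a hedged example of setting the renamed weight from Python, assuming the usual `TTS.tts.configs.vits_config` module path (the field name comes from the hunk above; the value is illustrative):

```python
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.consistency_loss_alpha = 9.0  # illustrative value; scales loss_consistency_enc in the generator loss
```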

View File

@@ -532,7 +532,7 @@ class VitsGeneratorLoss(nn.Module):
self.feat_loss_alpha = c.feat_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha
self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha
self.consistency_loss_alpha = c.consistency_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
c.audio.hop_length,
@@ -586,8 +586,8 @@ class VitsGeneratorLoss(nn.Module):
return l
@staticmethod
def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
def cosine_similarity_loss(gt_cons_emb, syn_cons_emb):
return -torch.nn.functional.cosine_similarity(gt_cons_emb, syn_cons_emb).mean()
def forward(
self,
@@ -602,9 +602,9 @@ class VitsGeneratorLoss(nn.Module):
feats_disc_fake,
feats_disc_real,
loss_duration,
use_speaker_encoder_as_loss=False,
gt_spk_emb=None,
syn_spk_emb=None,
use_encoder_consistency_loss=False,
gt_cons_emb=None,
syn_cons_emb=None,
):
"""
Shapes:
@@ -635,10 +635,10 @@ class VitsGeneratorLoss(nn.Module):
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
if use_speaker_encoder_as_loss:
loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
loss = loss + loss_se
return_dict["loss_spk_encoder"] = loss_se
if use_encoder_consistency_loss:
loss_enc = self.cosine_similarity_loss(gt_cons_emb, syn_cons_emb) * self.consistency_loss_alpha
loss = loss + loss_enc
return_dict["loss_consistency_enc"] = loss_enc
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl

View File

@@ -429,10 +429,10 @@ class VitsArgs(Coqpit):
use_speaker_encoder_as_loss (bool):
Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
speaker_encoder_config_path (str):
encoder_config_path (str):
Path to the file speaker encoder config file, to use for SCL. Defaults to "".
speaker_encoder_model_path (str):
encoder_model_path (str):
Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
condition_dp_on_speaker (bool):
@@ -507,8 +507,9 @@ class VitsArgs(Coqpit):
num_languages: int = 0
language_ids_file: str = None
use_speaker_encoder_as_loss: bool = False
speaker_encoder_config_path: str = ""
speaker_encoder_model_path: str = ""
use_emotion_encoder_as_loss: bool = False
encoder_config_path: str = ""
encoder_model_path: str = ""
condition_dp_on_speaker: bool = True
freeze_encoder: bool = False
freeze_DP: bool = False
@@ -559,6 +560,7 @@ class Vits(BaseTTS):
self.init_multispeaker(config)
self.init_multilingual(config)
self.init_emotion(config, emotion_manager)
self.init_consistency_loss()
self.length_scale = self.args.length_scale
self.noise_scale = self.args.noise_scale
@@ -661,15 +663,21 @@ class Vits(BaseTTS):
if self.args.use_d_vector_file:
self._init_d_vector()
# TODO: make this a function
if self.args.use_speaker_encoder_as_loss:
if self.speaker_manager.encoder is None and (
not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
):
raise RuntimeError(
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
def init_consistency_loss(self):
if self.args.use_speaker_encoder_as_loss and self.args.use_emotion_encoder_as_loss:
raise RuntimeError(
" [!] The use of speaker consistency loss (SCL) and emotion consistency loss (ECL) together is not supported, please disable one of those !!"
)
if self.args.use_speaker_encoder_as_loss:
if self.speaker_manager.encoder is None and (
not self.args.encoder_model_path or not self.args.encoder_config_path
):
raise RuntimeError(
" [!] To use the speaker consistency loss (SCL) you need to specify encoder_model_path and encoder_config_path !!"
)
# load encoder
self.speaker_manager.init_encoder(self.args.encoder_model_path, self.args.encoder_config_path)
self.speaker_manager.encoder.eval()
print(" > External Speaker Encoder Loaded !!")
@@ -677,15 +685,33 @@ class Vits(BaseTTS):
hasattr(self.speaker_manager.encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"]
):
# pylint: disable=W0101,W0105
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.audio_config["sample_rate"],
orig_freq=self.config.audio["sample_rate"],
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
)
# pylint: disable=W0101,W0105
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.config.audio.sample_rate,
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
)
elif self.args.use_emotion_encoder_as_loss:
if self.emotion_manager.encoder is None and (
not self.args.encoder_model_path or not self.args.encoder_config_path
):
raise RuntimeError(
" [!] To use the emotion consistency loss (ECL) you need to specify encoder_model_path and encoder_config_path !!"
)
# load encoder
self.emotion_manager.init_encoder(self.args.encoder_model_path, self.args.encoder_config_path)
self.emotion_manager.encoder.eval()
print(" > External Emotion Encoder Loaded !!")
if (
hasattr(self.emotion_manager.encoder, "audio_config")
and self.config.audio["sample_rate"] != self.emotion_manager.encoder.audio_config["sample_rate"]
):
# pylint: disable=W0101,W0105
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.config.audio["sample_rate"],
new_freq=self.emotion_manager.encoder.audio_config["sample_rate"],
)
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
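The resampling guard in `init_consistency_loss` only matters when the external encoder was trained at a different sample rate than the TTS output. Roughly, it boils down to the following hedged sketch (the rates are illustrative stand-ins for `config.audio["sample_rate"]` and `encoder.audio_config["sample_rate"]`):

```python
import torchaudio

model_sr = 22050    # TTS output sample rate
encoder_sr = 16000  # external encoder's expected sample rate

audio_transform = None
if model_sr != encoder_sr:
    # applied to the concatenated GT/generated batch before it is embedded
    audio_transform = torchaudio.transforms.Resample(orig_freq=model_sr, new_freq=encoder_sr)
```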
@@ -896,8 +922,8 @@ class Vits(BaseTTS):
- m_q: :math:`[B, C, T_dec]`
- logs_q: :math:`[B, C, T_dec]`
- waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]`
- gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
- syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
- gt_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
- syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
"""
outputs = {}
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
@@ -944,7 +970,8 @@ class Vits(BaseTTS):
pad_short=True,
)
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
if self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss:
encoder = self.speaker_manager.encoder if self.args.use_speaker_encoder_as_loss else self.emotion_manager.encoder
# concatenate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0)
@@ -953,12 +980,15 @@ class Vits(BaseTTS):
if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)
pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)
if next(encoder.parameters()).device != wavs_batch.device:
encoder = encoder.to(wavs_batch.device)
pred_embs = encoder.forward(wavs_batch, l2_norm=True)
# split generated and GT speaker embeddings
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
gt_cons_emb, syn_cons_emb = torch.chunk(pred_embs, 2, dim=0)
else:
gt_spk_emb, syn_spk_emb = None, None
gt_cons_emb, syn_cons_emb = None, None
outputs.update(
{
@@ -971,8 +1001,8 @@ class Vits(BaseTTS):
"m_q": m_q,
"logs_q": logs_q,
"waveform_seg": wav_seg,
"gt_spk_emb": gt_spk_emb,
"syn_spk_emb": syn_spk_emb,
"gt_cons_emb": gt_cons_emb,
"syn_cons_emb": syn_cons_emb,
"slice_ids": slice_ids,
}
)
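The embedding step above uses an "embed once, split" pattern: ground-truth segments and generated audio are concatenated along the batch dimension, pushed through the frozen encoder in a single forward pass, and the result is chunked back into the two halves. A hedged standalone sketch (the `encoder` callable is a stand-in for the real speaker/emotion encoder):

```python
import torch

def embed_pairs(encoder, wav_seg: torch.Tensor, generated: torch.Tensor):
    """Embed GT and generated segments in one pass, then split the result."""
    wavs_batch = torch.cat((wav_seg, generated), dim=0)  # [2B, 1, T]
    pred_embs = encoder(wavs_batch)                       # stands in for encoder.forward(..., l2_norm=True)
    gt_cons_emb, syn_cons_emb = torch.chunk(pred_embs, 2, dim=0)
    return gt_cons_emb, syn_cons_emb
```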
@@ -1058,6 +1088,7 @@ class Vits(BaseTTS):
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
return outputs
@torch.no_grad()
def inference_voice_conversion(self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None):
"""Inference for voice conversion
@@ -1201,9 +1232,9 @@ class Vits(BaseTTS):
feats_disc_fake=feats_disc_fake,
feats_disc_real=feats_disc_real,
loss_duration=self.model_outputs_cache["loss_duration"],
use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
use_encoder_consistency_loss=self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss,
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
)
return self.model_outputs_cache, loss_dict
@@ -1584,9 +1615,9 @@ class Vits(BaseTTS):
language_manager = LanguageManager.init_from_config(config)
emotion_manager = EmotionManager.init_from_config(config)
if config.model_args.speaker_encoder_model_path:
if config.model_args.encoder_model_path:
speaker_manager.init_encoder(
config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
config.model_args.encoder_model_path, config.model_args.encoder_config_path
)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)
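Putting the pieces together, enabling ECL is a matter of model args. A hedged example mirroring the commented-out lines in the test config further below, with placeholder checkpoint paths (SCL and ECL are mutually exclusive, as the `init_consistency_loss` check above enforces):

```python
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.use_emotion_encoder_as_loss = True
config.model_args.encoder_model_path = "/path/to/emotion_encoder/model.pth.tar"  # placeholder path
config.model_args.encoder_config_path = "/path/to/emotion_encoder/config.json"   # placeholder path
```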

View File

@@ -151,8 +151,8 @@ class ModelManager(object):
output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
encoder_config_path = os.path.join(output_path, "config_se.json")
encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
# update the scale_path.npy file path in the model config.json
self._update_path("audio.stats_path", output_stats_path, config_path)
@@ -166,10 +166,10 @@ class ModelManager(object):
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
# update the speaker_encoder file path in the model config.json to the current path
self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
self._update_path("encoder_model_path", encoder_model_path, config_path)
self._update_path("model_args.encoder_model_path", encoder_model_path, config_path)
self._update_path("encoder_config_path", encoder_config_path, config_path)
self._update_path("model_args.encoder_config_path", encoder_config_path, config_path)
@staticmethod
def _update_path(field_name, new_path, config_path):
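For context, `_update_path` rewrites a (possibly nested, dot-separated) key inside the downloaded model's config.json, which is why the renamed `encoder_model_path`/`encoder_config_path` fields must also be updated here. A rough, hedged sketch of what such a nested-path update amounts to (the real implementation may differ in its details):

```python
import json

def update_path(field_name: str, new_path: str, config_path: str) -> None:
    """Set a dot-separated field (e.g. "model_args.encoder_model_path") in a JSON config file."""
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    node = config
    keys = field_name.split(".")
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = new_path
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
```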

View File

@@ -36,15 +36,20 @@ config.audio.trim_db = 60
# active multispeaker d-vec mode
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = None
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 256
# emotion
config.model_args.use_external_emotions_embeddings = True
config.model_args.use_emotion_embedding = False
config.model_args.use_external_emotions_embeddings = False
config.model_args.use_emotion_embedding = True
config.model_args.emotion_embedding_dim = 256
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
# consistency loss
# config.model_args.use_emotion_encoder_as_loss = True
# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
# config.model_args.encoder_config_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/config_se.json"
config.save_json(config_path)
@@ -69,7 +74,7 @@ continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
emotion_id = "ljspeech-1"
emotion_id = "ljspeech-3"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
continue_emotion_path = os.path.join(continue_path, "speakers.json")