Add emotion consistency loss

commit 5090034fd1
parent cc3821332b
Author: Edresson Casanova
Date:   2022-03-15 12:35:00 +00:00

5 changed files with 88 additions and 52 deletions
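In short, this commit generalizes the existing speaker consistency loss (SCL) into a configurable encoder consistency loss, so the same negative-cosine-similarity term can be driven either by a speaker encoder or by an emotion encoder (emotion consistency loss, ECL). The loss term itself is unchanged apart from its argument names; a minimal, self-contained sketch with dummy tensors standing in for real encoder outputs:

    import torch
    import torch.nn.functional as F

    def cosine_similarity_loss(gt_cons_emb, syn_cons_emb):
        # Negative mean cosine similarity: minimizing it pulls the embedding of the
        # generated audio toward the embedding of the ground-truth audio.
        return -F.cosine_similarity(gt_cons_emb, syn_cons_emb).mean()

    # Dummy [batch, proj_dim] embeddings standing in for speaker/emotion encoder outputs.
    gt_cons_emb = F.normalize(torch.randn(4, 512), dim=1)
    syn_cons_emb = F.normalize(torch.randn(4, 512), dim=1)

    loss = cosine_similarity_loss(gt_cons_emb, syn_cons_emb)
    # loss is a scalar in [-1, 1]; it reaches -1 when both embeddings point in the same direction.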

View File

@@ -114,7 +114,7 @@ class VitsConfig(BaseTTSConfig):
     feat_loss_alpha: float = 1.0
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
-    speaker_encoder_loss_alpha: float = 1.0
+    consistency_loss_alpha: float = 1.0

     # data loader params
     return_wav: bool = True

View File

@ -532,7 +532,7 @@ class VitsGeneratorLoss(nn.Module):
self.feat_loss_alpha = c.feat_loss_alpha self.feat_loss_alpha = c.feat_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha self.dur_loss_alpha = c.dur_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha self.mel_loss_alpha = c.mel_loss_alpha
self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha self.consistency_loss_alpha = c.consistency_loss_alpha
self.stft = TorchSTFT( self.stft = TorchSTFT(
c.audio.fft_size, c.audio.fft_size,
c.audio.hop_length, c.audio.hop_length,
@ -586,8 +586,8 @@ class VitsGeneratorLoss(nn.Module):
return l return l
@staticmethod @staticmethod
def cosine_similarity_loss(gt_spk_emb, syn_spk_emb): def cosine_similarity_loss(gt_cons_emb, syn_cons_emb):
return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() return -torch.nn.functional.cosine_similarity(gt_cons_emb, syn_cons_emb).mean()
def forward( def forward(
self, self,
@ -602,9 +602,9 @@ class VitsGeneratorLoss(nn.Module):
feats_disc_fake, feats_disc_fake,
feats_disc_real, feats_disc_real,
loss_duration, loss_duration,
use_speaker_encoder_as_loss=False, use_encoder_consistency_loss=False,
gt_spk_emb=None, gt_cons_emb=None,
syn_spk_emb=None, syn_cons_emb=None,
): ):
""" """
Shapes: Shapes:
@ -635,10 +635,10 @@ class VitsGeneratorLoss(nn.Module):
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
if use_speaker_encoder_as_loss: if use_encoder_consistency_loss:
loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha loss_enc = self.cosine_similarity_loss(gt_cons_emb, syn_cons_emb) * self.consistency_loss_alpha
loss = loss + loss_se loss = loss + loss_enc
return_dict["loss_spk_encoder"] = loss_se return_dict["loss_consistency_enc"] = loss_enc
# pass losses to the dict # pass losses to the dict
return_dict["loss_gen"] = loss_gen return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl return_dict["loss_kl"] = loss_kl

View File

@@ -429,10 +429,10 @@ class VitsArgs(Coqpit):
         use_speaker_encoder_as_loss (bool):
             Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.

-        speaker_encoder_config_path (str):
+        encoder_config_path (str):
             Path to the file speaker encoder config file, to use for SCL. Defaults to "".

-        speaker_encoder_model_path (str):
+        encoder_model_path (str):
             Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".

         condition_dp_on_speaker (bool):
@@ -507,8 +507,9 @@ class VitsArgs(Coqpit):
     num_languages: int = 0
     language_ids_file: str = None
     use_speaker_encoder_as_loss: bool = False
-    speaker_encoder_config_path: str = ""
-    speaker_encoder_model_path: str = ""
+    use_emotion_encoder_as_loss: bool = False
+    encoder_config_path: str = ""
+    encoder_model_path: str = ""
     condition_dp_on_speaker: bool = True
     freeze_encoder: bool = False
     freeze_DP: bool = False
@@ -559,6 +560,7 @@ class Vits(BaseTTS):
         self.init_multispeaker(config)
         self.init_multilingual(config)
         self.init_emotion(config, emotion_manager)
+        self.init_consistency_loss()

         self.length_scale = self.args.length_scale
         self.noise_scale = self.args.noise_scale
@@ -661,15 +663,21 @@ class Vits(BaseTTS):
         if self.args.use_d_vector_file:
             self._init_d_vector()

-        # TODO: make this a function
-        if self.args.use_speaker_encoder_as_loss:
-            if self.speaker_manager.encoder is None and (
-                not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
-            ):
-                raise RuntimeError(
-                    " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
-                )
+    def init_consistency_loss(self):
+        if self.args.use_speaker_encoder_as_loss and self.args.use_emotion_encoder_as_loss:
+            raise RuntimeError(
+                " [!] The use of speaker consistency loss (SCL) and emotion consistency loss (ECL) together is not supported, please disable one of those !!"
+            )
+
+        if self.args.use_speaker_encoder_as_loss:
+            if self.speaker_manager.encoder is None and (
+                not self.args.encoder_model_path or not self.args.encoder_config_path
+            ):
+                raise RuntimeError(
+                    " [!] To use the speaker consistency loss (SCL) you need to specify encoder_model_path and encoder_config_path !!"
+                )
+            # load encoder
+            self.speaker_manager.init_encoder(self.args.encoder_model_path, self.args.encoder_config_path)
             self.speaker_manager.encoder.eval()
             print(" > External Speaker Encoder Loaded !!")
@@ -677,16 +685,34 @@ class Vits(BaseTTS):
                 hasattr(self.speaker_manager.encoder, "audio_config")
                 and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"]
             ):
-                self.audio_transform = torchaudio.transforms.Resample(
-                    orig_freq=self.audio_config["sample_rate"],
-                    new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
-                )
                 # pylint: disable=W0101,W0105
                 self.audio_transform = torchaudio.transforms.Resample(
-                    orig_freq=self.config.audio.sample_rate,
+                    orig_freq=self.config.audio["sample_rate"],
                     new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
                 )
+
+        elif self.args.use_emotion_encoder_as_loss:
+            if self.emotion_manager.encoder is None and (
+                not self.args.encoder_model_path or not self.args.encoder_config_path
+            ):
+                raise RuntimeError(
+                    " [!] To use the emotion consistency loss (ECL) you need to specify encoder_model_path and encoder_config_path !!"
+                )
+            # load encoder
+            self.emotion_manager.init_encoder(self.args.encoder_model_path, self.args.encoder_config_path)
+            self.emotion_manager.encoder.eval()
+            print(" > External Emotion Encoder Loaded !!")
+
+            if (
+                hasattr(self.emotion_manager.encoder, "audio_config")
+                and self.config.audio["sample_rate"] != self.emotion_manager.encoder.audio_config["sample_rate"]
+            ):
+                # pylint: disable=W0101,W0105
+                self.audio_transform = torchaudio.transforms.Resample(
+                    orig_freq=self.config.audio["sample_rate"],
+                    new_freq=self.emotion_manager.encoder.audio_config["sample_rate"],
+                )

     def _init_speaker_embedding(self):
         # pylint: disable=attribute-defined-outside-init
         if self.num_speakers > 0:
@@ -896,8 +922,8 @@ class Vits(BaseTTS):
             - m_q: :math:`[B, C, T_dec]`
             - logs_q: :math:`[B, C, T_dec]`
             - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]`
-            - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
-            - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
+            - gt_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
+            - syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
         """
         outputs = {}
         sid, g, lid, eid, eg = self._set_cond_input(aux_input)
@@ -944,7 +970,8 @@ class Vits(BaseTTS):
             pad_short=True,
         )

-        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
+        if self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss:
+            encoder = self.speaker_manager.encoder if self.args.use_speaker_encoder_as_loss else self.emotion_manager.encoder
             # concate generated and GT waveforms
             wavs_batch = torch.cat((wav_seg, o), dim=0)
@@ -953,12 +980,15 @@ class Vits(BaseTTS):
             if self.audio_transform is not None:
                 wavs_batch = self.audio_transform(wavs_batch)

-            pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)
+            if next(encoder.parameters()).device != wavs_batch.device:
+                encoder = encoder.to(wavs_batch.device)
+
+            pred_embs = encoder.forward(wavs_batch, l2_norm=True)

             # split generated and GT speaker embeddings
-            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
+            gt_cons_emb, syn_cons_emb = torch.chunk(pred_embs, 2, dim=0)
         else:
-            gt_spk_emb, syn_spk_emb = None, None
+            gt_cons_emb, syn_cons_emb = None, None

         outputs.update(
             {
@@ -971,8 +1001,8 @@ class Vits(BaseTTS):
                 "m_q": m_q,
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
-                "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb,
+                "gt_cons_emb": gt_cons_emb,
+                "syn_cons_emb": syn_cons_emb,
                 "slice_ids": slice_ids,
             }
         )
@@ -1058,6 +1088,7 @@ class Vits(BaseTTS):
         outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
         return outputs

     @torch.no_grad()
     def inference_voice_conversion(self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None):
         """Inference for voice conversion
@@ -1201,9 +1232,9 @@ class Vits(BaseTTS):
             feats_disc_fake=feats_disc_fake,
             feats_disc_real=feats_disc_real,
             loss_duration=self.model_outputs_cache["loss_duration"],
-            use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
-            gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
-            syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
+            use_encoder_consistency_loss=self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss,
+            gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
+            syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
         )

         return self.model_outputs_cache, loss_dict
@@ -1584,9 +1615,9 @@ class Vits(BaseTTS):
         language_manager = LanguageManager.init_from_config(config)
         emotion_manager = EmotionManager.init_from_config(config)

-        if config.model_args.speaker_encoder_model_path:
+        if config.model_args.encoder_model_path:
             speaker_manager.init_encoder(
-                config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
+                config.model_args.encoder_model_path, config.model_args.encoder_config_path
             )
         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)
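For orientation, the forward-pass change above obtains both consistency embeddings with a single encoder call: ground-truth and generated waveform segments are concatenated along the batch dimension, encoded once, and the embeddings are split back into two halves. A stand-alone sketch of that pattern, with a toy module standing in for the external speaker/emotion encoder (the real encoder and its l2_norm argument are not reproduced here):

    import torch
    import torch.nn as nn

    # Toy stand-in for the external encoder; the real one maps waveforms to fixed-size embeddings.
    encoder = nn.Sequential(nn.Flatten(), nn.LazyLinear(512))

    wav_seg = torch.randn(4, 1, 8192)  # ground-truth waveform segments [B, 1, T]
    o = torch.randn(4, 1, 8192)        # generated waveform segments [B, 1, T]

    wavs_batch = torch.cat((wav_seg, o), dim=0)                   # [2B, 1, T]
    pred_embs = encoder(wavs_batch)                               # [2B, 512]
    gt_cons_emb, syn_cons_emb = torch.chunk(pred_embs, 2, dim=0)  # each [B, 512]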

View File

@ -151,8 +151,8 @@ class ModelManager(object):
output_stats_path = os.path.join(output_path, "scale_stats.npy") output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_d_vector_file_path = os.path.join(output_path, "speakers.json") output_d_vector_file_path = os.path.join(output_path, "speakers.json")
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
speaker_encoder_config_path = os.path.join(output_path, "config_se.json") encoder_config_path = os.path.join(output_path, "config_se.json")
speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar") encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
# update the scale_path.npy file path in the model config.json # update the scale_path.npy file path in the model config.json
self._update_path("audio.stats_path", output_stats_path, config_path) self._update_path("audio.stats_path", output_stats_path, config_path)
@ -166,10 +166,10 @@ class ModelManager(object):
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
# update the speaker_encoder file path in the model config.json to the current path # update the speaker_encoder file path in the model config.json to the current path
self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path) self._update_path("encoder_model_path", encoder_model_path, config_path)
self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path) self._update_path("model_args.encoder_model_path", encoder_model_path, config_path)
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path) self._update_path("encoder_config_path", encoder_config_path, config_path)
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path) self._update_path("model_args.encoder_config_path", encoder_config_path, config_path)
@staticmethod @staticmethod
def _update_path(field_name, new_path, config_path): def _update_path(field_name, new_path, config_path):

View File

@@ -36,15 +36,20 @@ config.audio.trim_db = 60
 # active multispeaker d-vec mode
 config.model_args.use_speaker_embedding = True
 config.model_args.use_d_vector_file = False
-config.model_args.d_vector_file = None
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.speaker_embedding_channels = 128
 config.model_args.d_vector_dim = 256

 # emotion
-config.model_args.use_external_emotions_embeddings = True
-config.model_args.use_emotion_embedding = False
+config.model_args.use_external_emotions_embeddings = False
+config.model_args.use_emotion_embedding = True
 config.model_args.emotion_embedding_dim = 256
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"

+# consistency loss
+# config.model_args.use_emotion_encoder_as_loss = True
+# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
+# config.model_args.encoder_config_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/config_se.json"
+
 config.save_json(config_path)
@@ -69,7 +74,7 @@ continue_config_path = os.path.join(continue_path, "config.json")
 continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
-emotion_id = "ljspeech-1"
+emotion_id = "ljspeech-3"
 continue_speakers_path = os.path.join(continue_path, "speakers.json")
 continue_emotion_path = os.path.join(continue_path, "speakers.json")
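For reference, enabling the new loss outside of this test would look roughly like the commented block above, uncommented and pointed at a real encoder checkpoint. A hedged sketch using the renamed fields from this commit; the paths are placeholders, not files shipped with the repo:

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig()
    # Enable exactly one of the two; init_consistency_loss() raises if both are set.
    config.model_args.use_emotion_encoder_as_loss = True    # ECL
    # config.model_args.use_speaker_encoder_as_loss = True  # SCL
    config.model_args.encoder_model_path = "/path/to/encoder/model_se.pth.tar"  # placeholder
    config.model_args.encoder_config_path = "/path/to/encoder/config_se.json"   # placeholder
    config.consistency_loss_alpha = 1.0  # weight of the consistency term (default in this commit)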