mirror of https://github.com/coqui-ai/TTS.git

Add emotion consistency loss

This commit is contained in:
  commit 5090034fd1
  parent cc3821332b
@@ -114,7 +114,7 @@ class VitsConfig(BaseTTSConfig):
     feat_loss_alpha: float = 1.0
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
-    speaker_encoder_loss_alpha: float = 1.0
+    consistency_loss_alpha: float = 1.0
 
     # data loader params
     return_wav: bool = True
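Note: the renamed consistency_loss_alpha weights whichever consistency loss (speaker or emotion) is enabled. A minimal sketch of adjusting it on a config, assuming this repository's package layout (the 9.0 is just an example value):

# Sketch only: setting the consistency loss weight on a VitsConfig.
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.consistency_loss_alpha = 9.0  # example weight for the SCL/ECL term
print(config.consistency_loss_alpha)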
@@ -532,7 +532,7 @@ class VitsGeneratorLoss(nn.Module):
         self.feat_loss_alpha = c.feat_loss_alpha
         self.dur_loss_alpha = c.dur_loss_alpha
         self.mel_loss_alpha = c.mel_loss_alpha
-        self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha
+        self.consistency_loss_alpha = c.consistency_loss_alpha
         self.stft = TorchSTFT(
             c.audio.fft_size,
             c.audio.hop_length,
@@ -586,8 +586,8 @@ class VitsGeneratorLoss(nn.Module):
         return l
 
     @staticmethod
-    def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
-        return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
+    def cosine_similarity_loss(gt_cons_emb, syn_cons_emb):
+        return -torch.nn.functional.cosine_similarity(gt_cons_emb, syn_cons_emb).mean()
 
     def forward(
         self,
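Note: the loss itself is unchanged by the rename: it is the negative mean cosine similarity between ground-truth and synthesized embeddings, so it is lowest (-1) when the two embeddings align. A minimal self-contained check:

# Sketch: the consistency loss is negative mean cosine similarity.
import torch

def cosine_similarity_loss(gt_cons_emb, syn_cons_emb):
    return -torch.nn.functional.cosine_similarity(gt_cons_emb, syn_cons_emb).mean()

gt = torch.randn(4, 256)   # dummy ground-truth embeddings [batch, proj_dim]
syn = torch.randn(4, 256)  # dummy synthesized embeddings
print(cosine_similarity_loss(gt, syn))  # roughly 0 for random vectors
print(cosine_similarity_loss(gt, gt))   # -1.0 when embeddings match exactly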
@@ -602,9 +602,9 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_fake,
         feats_disc_real,
         loss_duration,
-        use_speaker_encoder_as_loss=False,
-        gt_spk_emb=None,
-        syn_spk_emb=None,
+        use_encoder_consistency_loss=False,
+        gt_cons_emb=None,
+        syn_cons_emb=None,
     ):
         """
         Shapes:
@@ -635,10 +635,10 @@ class VitsGeneratorLoss(nn.Module):
         loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
 
-        if use_speaker_encoder_as_loss:
-            loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
-            loss = loss + loss_se
-            return_dict["loss_spk_encoder"] = loss_se
+        if use_encoder_consistency_loss:
+            loss_enc = self.cosine_similarity_loss(gt_cons_emb, syn_cons_emb) * self.consistency_loss_alpha
+            loss = loss + loss_enc
+            return_dict["loss_consistency_enc"] = loss_enc
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
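Note: the renamed branch keeps the same computation, now shared by both encoders: the negative mean cosine similarity is scaled by consistency_loss_alpha, added to the total generator loss, and reported as loss_consistency_enc. A toy, self-contained sketch with a hypothetical helper (stand-in values, not the module's real tensors):

# Hypothetical helper mirroring the renamed branch above (illustration only).
import torch

def add_consistency_term(loss, gt_cons_emb, syn_cons_emb, alpha, return_dict):
    loss_enc = -torch.nn.functional.cosine_similarity(gt_cons_emb, syn_cons_emb).mean() * alpha
    return_dict["loss_consistency_enc"] = loss_enc
    return loss + loss_enc

return_dict = {}
total = add_consistency_term(torch.tensor(3.0), torch.randn(2, 64), torch.randn(2, 64), 1.0, return_dict)
print(total, return_dict["loss_consistency_enc"])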
@@ -429,10 +429,10 @@ class VitsArgs(Coqpit):
         use_speaker_encoder_as_loss (bool):
             Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
 
-        speaker_encoder_config_path (str):
+        encoder_config_path (str):
             Path to the file speaker encoder config file, to use for SCL. Defaults to "".
 
-        speaker_encoder_model_path (str):
+        encoder_model_path (str):
             Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
 
         condition_dp_on_speaker (bool):
@@ -507,8 +507,9 @@ class VitsArgs(Coqpit):
     num_languages: int = 0
     language_ids_file: str = None
     use_speaker_encoder_as_loss: bool = False
-    speaker_encoder_config_path: str = ""
-    speaker_encoder_model_path: str = ""
+    use_emotion_encoder_as_loss: bool = False
+    encoder_config_path: str = ""
+    encoder_model_path: str = ""
     condition_dp_on_speaker: bool = True
     freeze_encoder: bool = False
     freeze_DP: bool = False
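Note: with the renamed fields, a single encoder_model_path/encoder_config_path pair now serves either consistency loss, and the branch is picked by the two boolean flags. A minimal sketch of enabling ECL through model_args (the paths are placeholders, not shipped files):

# Sketch: enabling the emotion consistency loss via VitsArgs fields.
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.use_emotion_encoder_as_loss = True  # keep use_speaker_encoder_as_loss False
config.model_args.encoder_model_path = "/path/to/model_se.pth.tar"  # placeholder path
config.model_args.encoder_config_path = "/path/to/config_se.json"   # placeholder path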
@@ -559,6 +560,7 @@ class Vits(BaseTTS):
         self.init_multispeaker(config)
         self.init_multilingual(config)
         self.init_emotion(config, emotion_manager)
+        self.init_consistency_loss()
 
         self.length_scale = self.args.length_scale
         self.noise_scale = self.args.noise_scale
@@ -661,15 +663,21 @@ class Vits(BaseTTS):
         if self.args.use_d_vector_file:
             self._init_d_vector()
 
-        # TODO: make this a function
-        if self.args.use_speaker_encoder_as_loss:
-            if self.speaker_manager.encoder is None and (
-                not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
-            ):
-                raise RuntimeError(
-                    " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
-                )
+    def init_consistency_loss(self):
+        if self.args.use_speaker_encoder_as_loss and self.args.use_emotion_encoder_as_loss:
+            raise RuntimeError(
+                " [!] The use of speaker consistency loss (SCL) and emotion consistency loss (ECL) together is not supported, please disable one of those !!"
+            )
 
+        if self.args.use_speaker_encoder_as_loss:
+            if self.speaker_manager.encoder is None and (
+                not self.args.encoder_model_path or not self.args.encoder_config_path
+            ):
+                raise RuntimeError(
+                    " [!] To use the speaker consistency loss (SCL) you need to specify encoder_model_path and encoder_config_path !!"
+                )
+            # load encoder
+            self.speaker_manager.init_encoder(self.args.encoder_model_path, self.args.encoder_config_path)
             self.speaker_manager.encoder.eval()
             print(" > External Speaker Encoder Loaded !!")
@@ -677,16 +685,34 @@ class Vits(BaseTTS):
                 hasattr(self.speaker_manager.encoder, "audio_config")
                 and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"]
             ):
-                self.audio_transform = torchaudio.transforms.Resample(
-                    orig_freq=self.audio_config["sample_rate"],
-                    new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
-                )
                 # pylint: disable=W0101,W0105
                 self.audio_transform = torchaudio.transforms.Resample(
-                    orig_freq=self.config.audio.sample_rate,
+                    orig_freq=self.config.audio["sample_rate"],
                     new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
                 )
 
+        elif self.args.use_emotion_encoder_as_loss:
+            if self.emotion_manager.encoder is None and (
+                not self.args.encoder_model_path or not self.args.encoder_config_path
+            ):
+                raise RuntimeError(
+                    " [!] To use the emotion consistency loss (ECL) you need to specify encoder_model_path and encoder_config_path !!"
+                )
+            # load encoder
+            self.emotion_manager.init_encoder(self.args.encoder_model_path, self.args.encoder_config_path)
+            self.emotion_manager.encoder.eval()
+            print(" > External Emotion Encoder Loaded !!")
+
+            if (
+                hasattr(self.emotion_manager.encoder, "audio_config")
+                and self.config.audio["sample_rate"] != self.emotion_manager.encoder.audio_config["sample_rate"]
+            ):
+                # pylint: disable=W0101,W0105
+                self.audio_transform = torchaudio.transforms.Resample(
+                    orig_freq=self.config.audio["sample_rate"],
+                    new_freq=self.emotion_manager.encoder.audio_config["sample_rate"],
+                )
+
     def _init_speaker_embedding(self):
         # pylint: disable=attribute-defined-outside-init
         if self.num_speakers > 0:
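Note: both the speaker and emotion branches load an external encoder and, when its sample rate differs from the training audio, register a torchaudio resampler. A small standalone sketch of that resampling step (rates are example values, not taken from a real config):

# Sketch: resampling a waveform batch to an external encoder's sample rate.
import torch
import torchaudio

model_sample_rate = 22050    # example training sample rate
encoder_sample_rate = 16000  # example rate expected by the external encoder

audio_transform = torchaudio.transforms.Resample(orig_freq=model_sample_rate, new_freq=encoder_sample_rate)
wavs_batch = torch.randn(8, 22050)        # eight one-second dummy waveforms
print(audio_transform(wavs_batch).shape)  # torch.Size([8, 16000])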
@@ -896,8 +922,8 @@ class Vits(BaseTTS):
             - m_q: :math:`[B, C, T_dec]`
             - logs_q: :math:`[B, C, T_dec]`
             - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]`
-            - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
-            - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
+            - gt_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
+            - syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
         """
         outputs = {}
         sid, g, lid, eid, eg = self._set_cond_input(aux_input)
@@ -944,7 +970,8 @@ class Vits(BaseTTS):
             pad_short=True,
         )
 
-        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
+        if self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss:
+            encoder = self.speaker_manager.encoder if self.args.use_speaker_encoder_as_loss else self.emotion_manager.encoder
             # concate generated and GT waveforms
             wavs_batch = torch.cat((wav_seg, o), dim=0)
 
@@ -953,12 +980,15 @@ class Vits(BaseTTS):
             if self.audio_transform is not None:
                 wavs_batch = self.audio_transform(wavs_batch)
 
-            pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)
+            if next(encoder.parameters()).device != wavs_batch.device:
+                encoder = encoder.to(wavs_batch.device)
+
+            pred_embs = encoder.forward(wavs_batch, l2_norm=True)
 
             # split generated and GT speaker embeddings
-            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
+            gt_cons_emb, syn_cons_emb = torch.chunk(pred_embs, 2, dim=0)
         else:
-            gt_spk_emb, syn_spk_emb = None, None
+            gt_cons_emb, syn_cons_emb = None, None
 
         outputs.update(
             {
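Note: the forward pass embeds the ground-truth and generated segments in one batch and then splits the result, so a single encoder call yields both gt_cons_emb and syn_cons_emb. A self-contained sketch of that concatenate/embed/chunk flow (the Linear layer is a stand-in for the real speaker or emotion encoder):

# Sketch of the concat -> embed -> chunk pattern feeding the consistency loss.
import torch

encoder = torch.nn.Linear(16000, 256)  # stand-in embedding model
wav_seg = torch.randn(4, 16000)        # ground-truth waveform segments
o = torch.randn(4, 16000)              # generated waveform segments

wavs_batch = torch.cat((wav_seg, o), dim=0)                   # [8, 16000]
pred_embs = encoder(wavs_batch)                               # [8, 256]
gt_cons_emb, syn_cons_emb = torch.chunk(pred_embs, 2, dim=0)  # two [4, 256] halves
print(gt_cons_emb.shape, syn_cons_emb.shape)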
@@ -971,8 +1001,8 @@ class Vits(BaseTTS):
                 "m_q": m_q,
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
-                "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb,
+                "gt_cons_emb": gt_cons_emb,
+                "syn_cons_emb": syn_cons_emb,
                 "slice_ids": slice_ids,
             }
         )
@@ -1058,6 +1088,7 @@ class Vits(BaseTTS):
 
         outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
         return outputs
 
     @torch.no_grad()
     def inference_voice_conversion(self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None):
         """Inference for voice conversion
@@ -1201,9 +1232,9 @@ class Vits(BaseTTS):
                 feats_disc_fake=feats_disc_fake,
                 feats_disc_real=feats_disc_real,
                 loss_duration=self.model_outputs_cache["loss_duration"],
-                use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
-                gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
-                syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
+                use_encoder_consistency_loss=self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss,
+                gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
+                syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
             )
 
         return self.model_outputs_cache, loss_dict
@@ -1584,9 +1615,9 @@ class Vits(BaseTTS):
        language_manager = LanguageManager.init_from_config(config)
        emotion_manager = EmotionManager.init_from_config(config)

-        if config.model_args.speaker_encoder_model_path:
+        if config.model_args.encoder_model_path:
            speaker_manager.init_encoder(
-                config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
+                config.model_args.encoder_model_path, config.model_args.encoder_config_path
            )
        return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)
@@ -151,8 +151,8 @@ class ModelManager(object):
         output_stats_path = os.path.join(output_path, "scale_stats.npy")
         output_d_vector_file_path = os.path.join(output_path, "speakers.json")
         output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
-        speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
-        speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
+        encoder_config_path = os.path.join(output_path, "config_se.json")
+        encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
 
         # update the scale_path.npy file path in the model config.json
         self._update_path("audio.stats_path", output_stats_path, config_path)
@@ -166,10 +166,10 @@ class ModelManager(object):
         self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
 
         # update the speaker_encoder file path in the model config.json to the current path
-        self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
-        self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
-        self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
-        self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
+        self._update_path("encoder_model_path", encoder_model_path, config_path)
+        self._update_path("model_args.encoder_model_path", encoder_model_path, config_path)
+        self._update_path("encoder_config_path", encoder_config_path, config_path)
+        self._update_path("model_args.encoder_config_path", encoder_config_path, config_path)
 
     @staticmethod
     def _update_path(field_name, new_path, config_path):
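Note: ModelManager rewrites these keys, including the dotted model_args.* variants, inside a downloaded model's config.json so the local encoder files are used. A hedged sketch of the idea with a hypothetical helper (not the real _update_path implementation, which also reads and writes the file):

# Hypothetical sketch of rewriting a dotted config key such as "model_args.encoder_model_path".
import json

def update_path(field_name, new_path, config):
    keys = field_name.split(".")
    node = config
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = new_path

config = {"model_args": {}}
update_path("encoder_model_path", "/local/model_se.pth.tar", config)
update_path("model_args.encoder_config_path", "/local/config_se.json", config)
print(json.dumps(config, indent=2))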
@@ -36,15 +36,20 @@ config.audio.trim_db = 60
 # active multispeaker d-vec mode
 config.model_args.use_speaker_embedding = True
 config.model_args.use_d_vector_file = False
-config.model_args.d_vector_file = None
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.speaker_embedding_channels = 128
 config.model_args.d_vector_dim = 256
 
 # emotion
-config.model_args.use_external_emotions_embeddings = True
-config.model_args.use_emotion_embedding = False
+config.model_args.use_external_emotions_embeddings = False
+config.model_args.use_emotion_embedding = True
 config.model_args.emotion_embedding_dim = 256
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
 
+# consistency loss
+# config.model_args.use_emotion_encoder_as_loss = True
+# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
+# config.model_args.encoder_config_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/config_se.json"
+
 config.save_json(config_path)
 
@@ -69,7 +74,7 @@ continue_config_path = os.path.join(continue_path, "config.json")
 continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
-emotion_id = "ljspeech-1"
+emotion_id = "ljspeech-3"
 continue_speakers_path = os.path.join(continue_path, "speakers.json")
 continue_emotion_path = os.path.join(continue_path, "speakers.json")
 