Add Emotion Support for the VITS model

Edresson Casanova 2022-03-15 01:16:48 +00:00
parent ad7ce05ac9
commit bd99548016
7 changed files with 311 additions and 46 deletions

View File

@ -152,6 +152,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
# args for multi-speaker synthesis # args for multi-speaker synthesis
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
parser.add_argument("--emotions_file_path", type=str, help="JSON file for emotion model.", default=None)
parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None) parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
parser.add_argument( parser.add_argument(
"--speaker_idx", "--speaker_idx",
@ -165,6 +166,12 @@ If you don't specify any models, then it uses LJSpeech based English model.
help="Target language ID for a multi-lingual TTS model.", help="Target language ID for a multi-lingual TTS model.",
default=None, default=None,
) )
parser.add_argument(
"--emotion_idx",
type=str,
help="Target emotion ID.",
default=None,
)
parser.add_argument( parser.add_argument(
"--speaker_wav", "--speaker_wav",
nargs="+", nargs="+",
@ -254,6 +261,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
model_path = args.model_path model_path = args.model_path
config_path = args.config_path config_path = args.config_path
speakers_file_path = args.speakers_file_path speakers_file_path = args.speakers_file_path
emotions_file_path = args.emotions_file_path
language_ids_file_path = args.language_ids_file_path language_ids_file_path = args.language_ids_file_path
if args.vocoder_path is not None: if args.vocoder_path is not None:
@ -269,6 +277,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
model_path, model_path,
config_path, config_path,
speakers_file_path, speakers_file_path,
emotions_file_path,
language_ids_file_path, language_ids_file_path,
vocoder_path, vocoder_path,
vocoder_config_path, vocoder_config_path,
@ -315,6 +324,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
style_wav=args.capacitron_style_wav, style_wav=args.capacitron_style_wav,
style_text=args.capacitron_style_text, style_text=args.capacitron_style_text,
reference_speaker_name=args.reference_speaker_idx, reference_speaker_name=args.reference_speaker_idx,
emotion_name=args.emotion_idx,
) )
# save the results # save the results
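The new `--emotions_file_path` and `--emotion_idx` flags are threaded through to the Synthesizer the same way as the speaker flags. A minimal sketch of the equivalent Python call path added in this commit; the checkpoint, config, and JSON paths below are placeholders:

from TTS.utils.synthesizer import Synthesizer

# Sketch: what the CLI does internally with the new emotion flags (placeholder paths).
synthesizer = Synthesizer(
    tts_checkpoint="path/to/checkpoint.pth",
    tts_config_path="path/to/config.json",
    tts_speakers_file="path/to/speakers.json",
    tts_emotions_file="path/to/emotions.json",  # new argument in this commit
    use_cuda=False,
)
wav = synthesizer.tts(
    "Be a voice, not an echo.",
    speaker_name="ljspeech-1",
    emotion_name="Happy",  # new argument, resolved through the EmotionManager
)
synthesizer.save_wav(wav, "output.wav")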

View File

@ -428,3 +428,25 @@ class BaseTTS(BaseTrainerModel):
trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.") print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.") print(" > `language_ids_file` is updated in the config.json.")
if hasattr(self, "emotion_manager") and self.emotion_manager is not None:
output_path = os.path.join(trainer.output_path, "emotions.json")
if hasattr(trainer.config, "model_args"):
if trainer.config.model_args.use_emotion_embedding and not trainer.config.model_args.external_emotions_embs_file:
self.emotion_manager.save_ids_to_file(output_path)
trainer.config.model_args.emotions_ids_file = output_path
else:
self.emotion_manager.save_embeddings_to_file(output_path)
trainer.config.model_args.external_emotions_embs_file = output_path
else:
if trainer.config.use_emotion_embedding and not trainer.config.external_emotions_embs_file:
self.emotion_manager.save_ids_to_file(output_path)
trainer.config.emotions_ids_file = output_path
else:
self.emotion_manager.save_embeddings_to_file(output_path)
trainer.config.external_emotions_embs_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `emotions.json` is saved to {output_path}.")
print(" > `emotions_ids_file` or `external_emotions_embs_file` is updated in the config.json.")

View File

@ -25,6 +25,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.emotions import EmotionManager
from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.text.tokenizer import TTSTokenizer
@ -300,7 +301,7 @@ class VitsDataset(TTSDataset):
"waveform_lens": wav_lens, # (B) "waveform_lens": wav_lens, # (B)
"waveform_rel_lens": wav_rel_lens, "waveform_rel_lens": wav_rel_lens,
"speaker_names": batch["speaker_name"], "speaker_names": batch["speaker_name"],
"language_names": batch["language_name"], "f": batch["language_name"],
"audio_files": batch["wav_file"], "audio_files": batch["wav_file"],
"raw_text": batch["raw_text"], "raw_text": batch["raw_text"],
} }
@ -529,6 +530,15 @@ class VitsArgs(Coqpit):
speaker_embedding_channels: int = 256 speaker_embedding_channels: int = 256
use_d_vector_file: bool = False use_d_vector_file: bool = False
d_vector_dim: int = 0 d_vector_dim: int = 0
# use emotion embeddings
use_emotion_embedding: bool = False
use_external_emotions_embeddings: bool = False
emotions_ids_file: str = None
external_emotions_embs_file: str = None
emotion_embedding_dim: int = 0
num_emotions: int = 0
detach_dp_input: bool = True detach_dp_input: bool = True
use_language_embedding: bool = False use_language_embedding: bool = False
embedded_language_dim: int = 4 embedded_language_dim: int = 4
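The new VitsArgs fields choose between a trainable embedding looked up by emotion ID and externally computed emotion embeddings. A hedged configuration sketch for the learned-embedding path (the dimension value is illustrative):

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
# learned emotion embedding, looked up by emotion id
config.model_args.use_emotion_embedding = True
config.model_args.use_external_emotions_embeddings = False
config.model_args.emotion_embedding_dim = 64   # illustrative size
config.model_args.emotions_ids_file = None     # filled in by on_init_end during training
# alternatively, switch the flags to use precomputed embeddings:
# config.model_args.use_external_emotions_embeddings = True
# config.model_args.external_emotions_embs_file = "emotions.json"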
@ -584,13 +594,14 @@ class Vits(BaseTTS):
tokenizer: "TTSTokenizer" = None, tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None, speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None, language_manager: LanguageManager = None,
emotion_manager: EmotionManager = None,
): ):
super().__init__(config, ap, tokenizer, speaker_manager, language_manager) super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
self.init_multispeaker(config) self.init_multispeaker(config)
self.init_multilingual(config) self.init_multilingual(config)
self.init_upsampling() self.init_upsampling()
self.init_emotion(config, emotion_manager)
self.length_scale = self.args.length_scale self.length_scale = self.args.length_scale
self.noise_scale = self.args.noise_scale self.noise_scale = self.args.noise_scale
@ -619,7 +630,7 @@ class Vits(BaseTTS):
kernel_size=self.args.kernel_size_posterior_encoder, kernel_size=self.args.kernel_size_posterior_encoder,
dilation_rate=self.args.dilation_rate_posterior_encoder, dilation_rate=self.args.dilation_rate_posterior_encoder,
num_layers=self.args.num_layers_posterior_encoder, num_layers=self.args.num_layers_posterior_encoder,
cond_channels=self.embedded_speaker_dim, cond_channels=self.cond_embedding_dim,
) )
self.flow = ResidualCouplingBlocks( self.flow = ResidualCouplingBlocks(
@ -628,7 +639,7 @@ class Vits(BaseTTS):
kernel_size=self.args.kernel_size_flow, kernel_size=self.args.kernel_size_flow,
dilation_rate=self.args.dilation_rate_flow, dilation_rate=self.args.dilation_rate_flow,
num_layers=self.args.num_layers_flow, num_layers=self.args.num_layers_flow,
cond_channels=self.embedded_speaker_dim, cond_channels=self.cond_embedding_dim,
) )
if self.args.use_sdp: if self.args.use_sdp:
@ -638,7 +649,7 @@ class Vits(BaseTTS):
3, 3,
self.args.dropout_p_duration_predictor, self.args.dropout_p_duration_predictor,
4, 4,
cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, cond_channels=self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0,
language_emb_dim=self.embedded_language_dim, language_emb_dim=self.embedded_language_dim,
) )
else: else:
@ -647,7 +658,7 @@ class Vits(BaseTTS):
256, 256,
3, 3,
self.args.dropout_p_duration_predictor, self.args.dropout_p_duration_predictor,
cond_channels=self.embedded_speaker_dim, cond_channels=self.cond_embedding_dim,
language_emb_dim=self.embedded_language_dim, language_emb_dim=self.embedded_language_dim,
) )
@ -661,7 +672,7 @@ class Vits(BaseTTS):
self.args.upsample_initial_channel_decoder, self.args.upsample_initial_channel_decoder,
self.args.upsample_rates_decoder, self.args.upsample_rates_decoder,
inference_padding=0, inference_padding=0,
cond_channels=self.embedded_speaker_dim, cond_channels=self.cond_embedding_dim,
conv_pre_weight_norm=False, conv_pre_weight_norm=False,
conv_post_weight_norm=False, conv_post_weight_norm=False,
conv_post_bias=False, conv_post_bias=False,
@ -683,7 +694,7 @@ class Vits(BaseTTS):
config (Coqpit): Model configuration. config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None. data (List, optional): Dataset items to infer number of speakers. Defaults to None.
""" """
self.embedded_speaker_dim = 0 self.cond_embedding_dim = 0
self.num_speakers = self.args.num_speakers self.num_speakers = self.args.num_speakers
self.audio_transform = None self.audio_transform = None
@ -726,14 +737,14 @@ class Vits(BaseTTS):
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0: if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.") print(" > initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels self.cond_embedding_dim += self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.emb_g = nn.Embedding(self.num_speakers, self.args.speaker_embedding_channels)
def _init_d_vector(self): def _init_d_vector(self):
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
if hasattr(self, "emb_g"): if hasattr(self, "emb_g"):
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
self.embedded_speaker_dim = self.args.d_vector_dim self.cond_embedding_dim += self.args.d_vector_dim
def init_multilingual(self, config: Coqpit): def init_multilingual(self, config: Coqpit):
"""Initialize multilingual modules of a model. """Initialize multilingual modules of a model.
@ -785,9 +796,36 @@ class Vits(BaseTTS):
raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
print(" > Text Encoder was reinit.") print(" > Text Encoder was reinit.")
def init_emotion(self, config: Coqpit, emotion_manager: EmotionManager):
# pylint: disable=attribute-defined-outside-init
"""Initialize emotion modules of a model. A model can be trained either with a emotion embedding layer
or with external `embeddings` computed from a emotion encoder model.
You must provide a `emotion_manager` at initialization to set up the emotion modules.
Args:
config (Coqpit): Model configuration.
emotion_manager (Coqpit): Emotion Manager.
"""
self.emotion_manager = emotion_manager
self.num_emotions = self.args.num_emotions
if self.emotion_manager:
self.num_emotions = self.emotion_manager.num_emotions
if self.args.use_emotion_embedding:
if self.num_emotions > 0:
print(" > initialization of emotion-embedding layers.")
self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
self.cond_embedding_dim += self.args.emotion_embedding_dim
if self.args.use_external_emotions_embeddings:
self.cond_embedding_dim += self.args.emotion_embedding_dim
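With the rename from `embedded_speaker_dim` to `cond_embedding_dim`, the conditioning width now accumulates the speaker channels and, when enabled, the emotion embedding size; that total is what the posterior encoder, flow, duration predictor, and decoder receive as `cond_channels`. A small arithmetic sketch with illustrative sizes:

# Illustrative accumulation of the conditioning width used as `cond_channels`:
speaker_embedding_channels = 256  # from _init_speaker_embedding (or d_vector_dim)
emotion_embedding_dim = 64        # from init_emotion, when emotion support is enabled

cond_embedding_dim = 0
cond_embedding_dim += speaker_embedding_channels
cond_embedding_dim += emotion_embedding_dim
print(cond_embedding_dim)  # 320 -> matches the channel dim of torch.cat([g, eg], dim=1)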
def get_aux_input(self, aux_input: Dict): def get_aux_input(self, aux_input: Dict):
sid, g, lid = self._set_cond_input(aux_input) sid, g, lid, eid, eg = self._set_cond_input(aux_input)
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid,
"emotion_embeddings": eg, "emotion_ids": eid}
def _freeze_layers(self): def _freeze_layers(self):
if self.args.freeze_encoder: if self.args.freeze_encoder:
@ -817,7 +855,7 @@ class Vits(BaseTTS):
@staticmethod @staticmethod
def _set_cond_input(aux_input: Dict): def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode.""" """Set the speaker conditioning input based on the multi-speaker mode."""
sid, g, lid = None, None, None sid, g, lid, eid, eg = None, None, None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"] sid = aux_input["speaker_ids"]
if sid.ndim == 0: if sid.ndim == 0:
@ -832,7 +870,17 @@ class Vits(BaseTTS):
if lid.ndim == 0: if lid.ndim == 0:
lid = lid.unsqueeze_(0) lid = lid.unsqueeze_(0)
return sid, g, lid if "emotion_ids" in aux_input and aux_input["emotion_ids"] is not None:
eid = aux_input["emotion_ids"]
if eid.ndim == 0:
eid = eid.unsqueeze_(0)
if "emotion_embeddings" in aux_input and aux_input["emotion_embeddings"] is not None:
eg = F.normalize(aux_input["emotion_embeddings"]).unsqueeze(-1)
if eg.ndim == 2:
eg = eg.unsqueeze_(0)
return sid, g, lid, eid, eg
def _set_speaker_input(self, aux_input: Dict): def _set_speaker_input(self, aux_input: Dict):
d_vectors = aux_input.get("d_vectors", None) d_vectors = aux_input.get("d_vectors", None)
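`_set_cond_input` now also pulls `emotion_ids` and `emotion_embeddings` out of `aux_input`, unsqueezing scalar ids and L2-normalizing external embeddings. A standalone shape sketch of that normalization step (reimplemented here for illustration, with an assumed embedding size of 64):

import torch
import torch.nn.functional as F

# Shapes expected by the new emotion branch of _set_cond_input (sketch, B=1):
emotion_id = torch.tensor(2)            # 0-d id -> unsqueezed to [1]
emotion_embedding = torch.randn(1, 64)  # [B, emotion_embedding_dim]

eid = emotion_id.unsqueeze(0) if emotion_id.ndim == 0 else emotion_id
eg = F.normalize(emotion_embedding).unsqueeze(-1)  # [B, emotion_embedding_dim, 1]
print(eid.shape, eg.shape)  # torch.Size([1]) torch.Size([1, 64, 1])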
@ -906,7 +954,7 @@ class Vits(BaseTTS):
y: torch.tensor, y: torch.tensor,
y_lengths: torch.tensor, y_lengths: torch.tensor,
waveform: torch.tensor, waveform: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None},
) -> Dict: ) -> Dict:
"""Forward pass of the model. """Forward pass of the model.
@ -946,11 +994,19 @@ class Vits(BaseTTS):
- syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
""" """
outputs = {} outputs = {}
sid, g, lid = self._set_cond_input(aux_input) sid, g, lid, eid, eg = self._set_cond_input(aux_input)
# speaker embedding # speaker embedding
if self.args.use_speaker_embedding and sid is not None: if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# emotion embedding
if self.args.use_emotion_embedding and eid is not None and eg is None:
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
# concat the emotion embedding and speaker embedding
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
# language embedding # language embedding
lang_emb = None lang_emb = None
if self.args.use_language_embedding and lid is not None: if self.args.use_language_embedding and lid is not None:
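In the forward pass the emotion embedding is concatenated with the speaker embedding along the channel axis, so the rest of the network sees a single conditioning tensor of width `cond_embedding_dim`. A shape-only sketch with illustrative sizes:

import torch

B = 2
g = torch.randn(B, 256, 1)     # speaker embedding / d-vector, [B, h_spk, 1]
eg = torch.randn(B, 64, 1)     # emotion embedding, [B, h_emo, 1]

g = torch.cat([g, eg], dim=1)  # [B, h_spk + h_emo, 1] -> fed as `g` to flow, SDP, decoder
print(g.shape)                 # torch.Size([2, 320, 1])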
@ -1028,7 +1084,7 @@ class Vits(BaseTTS):
@torch.no_grad() @torch.no_grad()
def inference( def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None} self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None}
): # pylint: disable=dangerous-default-value ): # pylint: disable=dangerous-default-value
""" """
Note: Note:
@ -1048,13 +1104,21 @@ class Vits(BaseTTS):
- m_p: :math:`[B, C, T_dec]` - m_p: :math:`[B, C, T_dec]`
- logs_p: :math:`[B, C, T_dec]` - logs_p: :math:`[B, C, T_dec]`
""" """
sid, g, lid = self._set_cond_input(aux_input) sid, g, lid, eid, eg = self._set_cond_input(aux_input)
x_lengths = self._set_x_lengths(x, aux_input) x_lengths = self._set_x_lengths(x, aux_input)
# speaker embedding # speaker embedding
if self.args.use_speaker_embedding and sid is not None: if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) g = self.emb_g(sid).unsqueeze(-1)
# emotion embedding
if self.args.use_emotion_embedding and eid is not None and eg is None:
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
# concat the emotion embedding and speaker embedding
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
# language embedding # language embedding
lang_emb = None lang_emb = None
if self.args.use_language_embedding and lid is not None: if self.args.use_language_embedding and lid is not None:
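At inference time the same aux_input keys are accepted, so the emotion can be selected either by id (learned embedding) or by a precomputed embedding vector. A sketch of the dictionary that would be passed to `Vits.inference`; the model instance and tokenized input are assumed, not constructed here:

import torch

# Sketch of the aux_input accepted by Vits.inference() after this change.
# With `use_emotion_embedding`, pass integer `emotion_ids`; with
# `use_external_emotions_embeddings`, pass a [B, emotion_embedding_dim] tensor instead.
aux_input = {
    "x_lengths": torch.tensor([50]),
    "speaker_ids": torch.tensor([0]),
    "d_vectors": None,
    "language_ids": None,
    "emotion_ids": torch.tensor([1]),
    "emotion_embeddings": None,
}
# outputs = model.inference(token_ids, aux_input=aux_input)  # `model`, `token_ids` assumed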
@ -1187,6 +1251,8 @@ class Vits(BaseTTS):
d_vectors = batch["d_vectors"] d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"] speaker_ids = batch["speaker_ids"]
language_ids = batch["language_ids"] language_ids = batch["language_ids"]
emotion_embeddings = batch["emotion_embeddings"]
emotion_ids = batch["emotion_ids"]
waveform = batch["waveform"] waveform = batch["waveform"]
# generator pass # generator pass
@ -1196,7 +1262,8 @@ class Vits(BaseTTS):
spec, spec,
spec_lens, spec_lens,
waveform, waveform,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids,
"emotion_embeddings": emotion_embeddings, "emotion_ids": emotion_ids},
) )
# cache tensors for the generator pass # cache tensors for the generator pass
@ -1322,7 +1389,7 @@ class Vits(BaseTTS):
config = self.config config = self.config
# extract speaker and language info # extract speaker and language info
text, speaker_name, style_wav, language_name = None, None, None, None text, speaker_name, style_wav, language_name, emotion_name = None, None, None, None, None
if isinstance(sentence_info, list): if isinstance(sentence_info, list):
if len(sentence_info) == 1: if len(sentence_info) == 1:
@ -1333,11 +1400,13 @@ class Vits(BaseTTS):
text, speaker_name, style_wav = sentence_info text, speaker_name, style_wav = sentence_info
elif len(sentence_info) == 4: elif len(sentence_info) == 4:
text, speaker_name, style_wav, language_name = sentence_info text, speaker_name, style_wav, language_name = sentence_info
elif len(sentence_info) == 5:
text, speaker_name, style_wav, language_name, emotion_name = sentence_info
else: else:
text = sentence_info text = sentence_info
# get speaker id/d_vector # get speaker id/d_vector
speaker_id, d_vector, language_id = None, None, None speaker_id, d_vector, language_id, emotion_id, emotion_embedding = None, None, None, None, None
if hasattr(self, "speaker_manager"): if hasattr(self, "speaker_manager"):
if config.use_d_vector_file: if config.use_d_vector_file:
if speaker_name is None: if speaker_name is None:
@ -1354,6 +1423,19 @@ class Vits(BaseTTS):
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.ids[language_name] language_id = self.language_manager.ids[language_name]
# get emotion id/embedding
if hasattr(self, "emotion_manager"):
if config.use_external_emotions_embeddings:
if emotion_name is None:
emotion_embedding = self.emotion_manager.get_random_embeddings()
else:
emotion_embedding = self.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
elif config.use_emotion_embedding:
if emotion_name is None:
emotion_id = self.emotion_manager.get_random_id()
else:
emotion_id = self.emotion_manager.ids[emotion_name]
return { return {
"text": text, "text": text,
"speaker_id": speaker_id, "speaker_id": speaker_id,
@ -1361,6 +1443,8 @@ class Vits(BaseTTS):
"d_vector": d_vector, "d_vector": d_vector,
"language_id": language_id, "language_id": language_id,
"language_name": language_name, "language_name": language_name,
"emotion_embedding": emotion_embedding,
"emotion_ids": emotion_id
} }
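Test sentences can now carry a fifth element naming the emotion, which is exactly the format the training test at the end of this diff uses. A short example of the 5-element layout parsed above (the second sentence is illustrative):

# [text, speaker_name, style_wav, language_name, emotion_name]
test_sentences = [
    ["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"],
    ["This is another test sentence.", "ljspeech-1", None, None, "ljspeech-1"],
]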
@torch.no_grad() @torch.no_grad()
@ -1387,6 +1471,8 @@ class Vits(BaseTTS):
d_vector=aux_inputs["d_vector"], d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"], style_wav=aux_inputs["style_wav"],
language_id=aux_inputs["language_id"], language_id=aux_inputs["language_id"],
emotion_embedding=aux_inputs["emotion_embedding"],
emotion_id=aux_inputs["emotion_ids"],
use_griffin_lim=True, use_griffin_lim=True,
do_trim_silence=False, do_trim_silence=False,
).values() ).values()
@ -1401,10 +1487,12 @@ class Vits(BaseTTS):
logger.test_figures(steps, outputs["figures"]) logger.test_figures(steps, outputs["figures"])
def format_batch(self, batch: Dict) -> Dict: def format_batch(self, batch: Dict) -> Dict:
"""Compute speaker, langugage IDs and d_vector for the batch if necessary.""" """Compute speaker, langugage IDs, d_vector and emotion embeddings for the batch if necessary."""
speaker_ids = None speaker_ids = None
language_ids = None language_ids = None
d_vectors = None d_vectors = None
emotion_embeddings = None
emotion_ids = None
# get numerical speaker ids from speaker names # get numerical speaker ids from speaker names
if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding: if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding:
@ -1421,15 +1509,33 @@ class Vits(BaseTTS):
d_vectors = torch.FloatTensor(d_vectors) d_vectors = torch.FloatTensor(d_vectors)
# get language ids from language names # get language ids from language names
if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding: if (
self.language_manager is not None
and self.language_manager.ids
and self.args.use_language_embedding
):
language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]] language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]
if language_ids is not None: if language_ids is not None:
language_ids = torch.LongTensor(language_ids) language_ids = torch.LongTensor(language_ids)
# get emotion embedding
if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_external_emotions_embeddings:
emotion_mapping = self.emotion_manager.embeddings
emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
emotion_embeddings = torch.FloatTensor(emotion_embeddings)
if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_emotion_embedding:
emotion_mapping = self.emotion_manager.embeddings
emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]]
emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
emotion_ids = torch.LongTensor(emotion_ids)
batch["language_ids"] = language_ids batch["language_ids"] = language_ids
batch["d_vectors"] = d_vectors batch["d_vectors"] = d_vectors
batch["speaker_ids"] = speaker_ids batch["speaker_ids"] = speaker_ids
batch["emotion_embeddings"] = emotion_embeddings
batch["emotion_ids"] = emotion_ids
return batch return batch
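`format_batch` looks up emotion information per audio file path, so the EmotionManager's embeddings mapping is keyed by wav path with a name and an embedding per entry. A runnable sketch of that lookup; the mapping layout is inferred from this code, and the paths, labels, and 3-dim embeddings are illustrative:

import torch

# Inferred layout: audio file path -> {"name": <emotion>, "embedding": <vector>}
emotion_mapping = {
    "wavs/sample_0001.wav": {"name": "Happy", "embedding": [0.01, -0.43, 0.22]},
    "wavs/sample_0002.wav": {"name": "Sad", "embedding": [0.18, 0.05, -0.31]},
}
audio_files = ["wavs/sample_0002.wav", "wavs/sample_0001.wav"]

emotion_embeddings = torch.FloatTensor([emotion_mapping[w]["embedding"] for w in audio_files])
emotion_ids_map = {"Happy": 0, "Sad": 1}
emotion_ids = torch.LongTensor([emotion_ids_map[emotion_mapping[w]["name"]] for w in audio_files])
print(emotion_embeddings.shape, emotion_ids)  # torch.Size([2, 3]) tensor([1, 0])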
def format_batch_on_device(self, batch): def format_batch_on_device(self, batch):
@ -1643,12 +1749,13 @@ class Vits(BaseTTS):
tokenizer, new_config = TTSTokenizer.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples) speaker_manager = SpeakerManager.init_from_config(config, samples)
language_manager = LanguageManager.init_from_config(config) language_manager = LanguageManager.init_from_config(config)
emotion_manager = EmotionManager.init_from_config(config)
if config.model_args.speaker_encoder_model_path: if config.model_args.speaker_encoder_model_path:
speaker_manager.init_encoder( speaker_manager.init_encoder(
config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
) )
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)
################################## ##################################

View File

@ -1,10 +1,8 @@
import json import json
import os import os
from typing import Any, Dict, List, Tuple, Union from typing import Any, List
import fsspec import fsspec
import numpy as np
import torch
from coqpit import Coqpit from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default from TTS.config import get_from_config_or_model_args_with_default
@ -34,7 +32,7 @@ class EmotionManager(EmbeddingManager):
computes the embeddings for a given clip or emotion. computes the embeddings for a given clip or emotion.
Args: Args:
emo_embeddings_file_path (str, optional): Path to the metafile including x vectors. Defaults to "". embeddings_file_path (str, optional): Path to the metafile that includes the emotion embeddings. Defaults to "".
emotion_id_file_path (str, optional): Path to the metafile that maps emotion names to ids used by emotion_id_file_path (str, optional): Path to the metafile that maps emotion names to ids used by
TTS models. Defaults to "". TTS models. Defaults to "".
encoder_model_path (str, optional): Path to the emotion encoder model file. Defaults to "". encoder_model_path (str, optional): Path to the emotion encoder model file. Defaults to "".
@ -50,14 +48,14 @@ class EmotionManager(EmbeddingManager):
def __init__( def __init__(
self, self,
emo_embeddings_file_path: str = "", embeddings_file_path: str = "",
emotion_id_file_path: str = "", emotion_id_file_path: str = "",
encoder_model_path: str = "", encoder_model_path: str = "",
encoder_config_path: str = "", encoder_config_path: str = "",
use_cuda: bool = False, use_cuda: bool = False,
): ):
super().__init__( super().__init__(
external_emotions_ids_file_path=emo_embeddings_file_path, embedding_file_path=embeddings_file_path,
id_file_path=emotion_id_file_path, id_file_path=emotion_id_file_path,
encoder_model_path=encoder_model_path, encoder_model_path=encoder_model_path,
encoder_config_path=encoder_config_path, encoder_config_path=encoder_config_path,
@ -98,11 +96,15 @@ class EmotionManager(EmbeddingManager):
emotion_manager = EmotionManager( emotion_manager = EmotionManager(
emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None) emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
) )
elif get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
if get_from_config_or_model_args_with_default(config, "use_external_emotion_embedding", False):
if get_from_config_or_model_args_with_default(config, "external_emotions_ids_file", None):
emotion_manager = EmotionManager( emotion_manager = EmotionManager(
embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_ids_file", None) embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
)
if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False):
if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
emotion_manager = EmotionManager(
embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
) )
return emotion_manager return emotion_manager
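`EmotionManager.init_from_config` picks either the ids file or the external embeddings file depending on the config flags; the same manager can also be constructed directly. A hedged construction sketch, with placeholder file paths:

from TTS.tts.utils.emotions import EmotionManager

# Learned-embedding setup: the ids file maps emotion names to integer ids (placeholder path).
manager = EmotionManager(emotion_id_file_path="emotions.json")
print(manager.num_emotions)

# External-embedding setup: the embeddings file maps wav paths to name/embedding entries.
manager = EmotionManager(embeddings_file_path="emotions.json")
print(manager.num_emotions)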
@ -157,26 +159,26 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
if c.use_external_emotions_embeddings: if c.use_external_emotions_embeddings:
# restore emotion manager with the embedding file # restore emotion manager with the embedding file
if not os.path.exists(emotions_ids_file): if not os.path.exists(emotions_ids_file):
print("WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_ids_file") print("WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_embs_file")
if not os.path.exists(c.external_emotions_ids_file): if not os.path.exists(c.external_emotions_embs_file):
raise RuntimeError( raise RuntimeError(
"You must copy the file emotions.json to restore_path, or set a valid file in CONFIG.external_emotions_ids_file" "You must copy the file emotions.json to restore_path, or set a valid file in CONFIG.external_emotions_embs_file"
) )
emotion_manager.load_embeddings_from_file(c.external_emotions_ids_file) emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)
emotion_manager.load_embeddings_from_file(emotions_ids_file) emotion_manager.load_embeddings_from_file(emotions_ids_file)
elif not c.use_external_emotions_embeddings: # restore emotion manager with the emotion ID file. elif not c.use_external_emotions_embeddings: # restore emotion manager with the emotion ID file.
emotion_manager.load_ids_from_file(emotions_ids_file) emotion_manager.load_ids_from_file(emotions_ids_file)
elif c.use_external_emotions_embeddings and c.external_emotions_ids_file: elif c.use_external_emotions_embeddings and c.external_emotions_embs_file:
# new emotion manager with external emotion embeddings. # new emotion manager with external emotion embeddings.
emotion_manager.load_embeddings_from_file(c.external_emotions_ids_file) emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)
elif c.use_external_emotions_embeddings and not c.external_emotions_ids_file: elif c.use_external_emotions_embeddings and not c.external_emotions_embs_file:
raise "use_external_emotions_embeddings is True, so you need pass a external emotion embedding file." raise "use_external_emotions_embeddings is True, so you need pass a external emotion embedding file."
elif c.use_emotion_embedding: elif c.use_emotion_embedding:
if "emotions_ids_file" in c and c.emotions_ids_file: if "emotions_ids_file" in c and c.emotions_ids_file:
emotion_manager.load_ids_from_file(c.emotions_ids_file) emotion_manager.load_ids_from_file(c.emotions_ids_file)
else: # get the ids from the external embeddings file else: # get the ids from the external embeddings file
emotion_manager.load_embeddings_from_file(c.external_emotions_ids_file) emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)
if emotion_manager.num_emotions > 0: if emotion_manager.num_emotions > 0:
print( print(
@ -189,9 +191,8 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
if out_path: if out_path:
out_file_path = os.path.join(out_path, "emotions.json") out_file_path = os.path.join(out_path, "emotions.json")
print(f" > Saving `emotions.json` to {out_file_path}.") print(f" > Saving `emotions.json` to {out_file_path}.")
if c.use_external_emotions_embeddings and c.external_emotions_ids_file: if c.use_external_emotions_embeddings and c.external_emotions_embs_file:
emotion_manager.save_embeddings_to_file(out_file_path) emotion_manager.save_embeddings_to_file(out_file_path)
else: else:
emotion_manager.save_ids_to_file(out_file_path) emotion_manager.save_ids_to_file(out_file_path)
return emotion_manager return emotion_manager

View File

@ -29,6 +29,8 @@ def run_model_torch(
style_text: str = None, style_text: str = None,
d_vector: torch.Tensor = None, d_vector: torch.Tensor = None,
language_id: torch.Tensor = None, language_id: torch.Tensor = None,
emotion_id: torch.Tensor = None,
emotion_embedding: torch.Tensor = None,
) -> Dict: ) -> Dict:
"""Run a torch model for inference. It does not support batch inference. """Run a torch model for inference. It does not support batch inference.
@ -56,6 +58,8 @@ def run_model_torch(
"style_mel": style_mel, "style_mel": style_mel,
"style_text": style_text, "style_text": style_text,
"language_ids": language_id, "language_ids": language_id,
"emotion_ids": emotion_id,
"emotion_embeddings": emotion_embedding,
}, },
) )
return outputs return outputs
@ -122,6 +126,8 @@ def synthesis(
do_trim_silence=False, do_trim_silence=False,
d_vector=None, d_vector=None,
language_id=None, language_id=None,
emotion_id=None,
emotion_embedding=None
): ):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
the vocoder model. the vocoder model.
@ -190,6 +196,12 @@ def synthesis(
if language_id is not None: if language_id is not None:
language_id = id_to_torch(language_id, cuda=use_cuda) language_id = id_to_torch(language_id, cuda=use_cuda)
if emotion_id is not None:
emotion_id = id_to_torch(emotion_id, cuda=use_cuda)
if emotion_embedding is not None:
emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda)
if not isinstance(style_mel, dict): if not isinstance(style_mel, dict):
# GST or Capacitron style mel # GST or Capacitron style mel
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
@ -212,6 +224,8 @@ def synthesis(
style_text, style_text,
d_vector=d_vector, d_vector=d_vector,
language_id=language_id, language_id=language_id,
emotion_id=emotion_id,
emotion_embedding=emotion_embedding,
) )
model_outputs = outputs["model_outputs"] model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy() model_outputs = model_outputs[0].data.cpu().numpy()
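`synthesis()` now forwards `emotion_id` or `emotion_embedding` down to `run_model_torch`, mirroring the speaker arguments. A hedged call sketch wrapped in a helper; `model` is assumed to be a loaded, emotion-enabled Vits instance and `config` its Coqpit config, and only one of the two emotion arguments should be set:

from TTS.tts.utils.synthesis import synthesis

def synthesize_with_emotion(model, config, text, emotion_id=None, emotion_embedding=None):
    """Sketch: run an emotion-enabled Vits through the updated synthesis() helper."""
    outputs = synthesis(
        model,
        text,
        config,
        use_cuda=False,
        speaker_id=0,
        emotion_id=emotion_id,            # for a learned emotion embedding
        emotion_embedding=emotion_embedding,  # or a precomputed embedding vector
        use_griffin_lim=True,
        do_trim_silence=False,
    )
    return outputs["wav"]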

View File

@ -22,6 +22,7 @@ class Synthesizer(object):
tts_checkpoint: str, tts_checkpoint: str,
tts_config_path: str, tts_config_path: str,
tts_speakers_file: str = "", tts_speakers_file: str = "",
tts_emotions_file: str = "",
tts_languages_file: str = "", tts_languages_file: str = "",
vocoder_checkpoint: str = "", vocoder_checkpoint: str = "",
vocoder_config: str = "", vocoder_config: str = "",
@ -52,6 +53,7 @@ class Synthesizer(object):
self.tts_checkpoint = tts_checkpoint self.tts_checkpoint = tts_checkpoint
self.tts_config_path = tts_config_path self.tts_config_path = tts_config_path
self.tts_speakers_file = tts_speakers_file self.tts_speakers_file = tts_speakers_file
self.tts_emotions_file = tts_emotions_file
self.tts_languages_file = tts_languages_file self.tts_languages_file = tts_languages_file
self.vocoder_checkpoint = vocoder_checkpoint self.vocoder_checkpoint = vocoder_checkpoint
self.vocoder_config = vocoder_config self.vocoder_config = vocoder_config
@ -183,6 +185,7 @@ class Synthesizer(object):
style_text=None, style_text=None,
reference_wav=None, reference_wav=None,
reference_speaker_name=None, reference_speaker_name=None,
emotion_name=None,
) -> List[int]: ) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech. """🐸 TTS magic. Run all the models and generate speech.
@ -240,7 +243,7 @@ class Synthesizer(object):
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. " "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
) )
# handle multi-lingaul # handle multi-lingual
language_id = None language_id = None
if self.tts_languages_file or ( if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
@ -260,6 +263,29 @@ class Synthesizer(object):
"Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. " "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
) )
# handle emotion
emotion_embedding, emotion_id = None, None
if self.tts_emotions_file or (hasattr(self.tts_model, "emotion_manager") and self.tts_model.emotion_manager is not None):
if emotion_name and isinstance(emotion_name, str):
if getattr(self.tts_config, "use_external_emotions_embeddings", False) or getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False):
# get the average emotion embedding from the saved embeddings.
emotion_embedding = self.tts_model.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
emotion_embedding = np.array(emotion_embedding)[None, :] # [1 x embedding_dim]
else:
# get the emotion id from the emotion name
emotion_id = self.tts_model.emotion_manager.ids[emotion_name]
elif not emotion_name:
raise ValueError(
" [!] Look like you use an emotion model. "
"You need to define either a `emotion_name` to use an emotion model."
)
else:
if emotion_name:
raise ValueError(
f" [!] Missing emotion.json file path for selecting the emotion {emotion_name}."
"Define path for emotion.json if it is an emotion model or remove defined emotion idx. "
)
# compute a new d_vector from the given clip. # compute a new d_vector from the given clip.
if speaker_wav is not None: if speaker_wav is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav) speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
@ -280,6 +306,8 @@ class Synthesizer(object):
use_griffin_lim=use_gl, use_griffin_lim=use_gl,
d_vector=speaker_embedding, d_vector=speaker_embedding,
language_id=language_id, language_id=language_id,
emotion_embedding=emotion_embedding,
emotion_id=emotion_id,
) )
waveform = outputs["wav"] waveform = outputs["wav"]
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()

View File

@ -0,0 +1,83 @@
import glob
import os
import shutil
from trainer import get_last_checkpoint
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"],
],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60
# enable multi-speaker mode with a learned speaker embedding (no d-vector file)
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = None
config.model_args.d_vector_dim = 256
# enable emotion support with external emotion embeddings (the speakers.json d-vector file doubles as the emotion embeddings in this test)
config.model_args.use_external_emotions_embeddings = True
config.model_args.use_emotion_embedding = False
config.model_args.emotion_embedding_dim = 256
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech_test "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
emotion_id = "ljspeech-1"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
continue_emotion_path = os.path.join(continue_path, "speakers.json")
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)