Add Emotion Support for the VITS model

This commit is contained in:
Edresson Casanova 2022-03-15 01:16:48 +00:00
parent 18d3565d37
commit e3520e9e9f
7 changed files with 306 additions and 48 deletions

View File

@ -152,6 +152,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
# args for multi-speaker synthesis
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
parser.add_argument("--emotions_file_path", type=str, help="JSON file for emotion model.", default=None)
parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
parser.add_argument(
"--speaker_idx",
@ -165,6 +166,12 @@ If you don't specify any models, then it uses LJSpeech based English model.
help="Target language ID for a multi-lingual TTS model.",
default=None,
)
parser.add_argument(
"--emotion_idx",
type=str,
help="Target emotion ID.",
default=None,
)
parser.add_argument(
"--speaker_wav",
nargs="+",
@ -244,6 +251,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
model_path = args.model_path
config_path = args.config_path
speakers_file_path = args.speakers_file_path
emotions_file_path = args.emotions_file_path
language_ids_file_path = args.language_ids_file_path
if args.vocoder_path is not None:
@ -259,6 +267,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
model_path,
config_path,
speakers_file_path,
emotions_file_path,
language_ids_file_path,
vocoder_path,
vocoder_config_path,
@ -296,7 +305,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(" > Text: {}".format(args.text))
# kick it
wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav, reference_wav=args.reference_wav, reference_speaker_name=args.reference_speaker_idx)
wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav, reference_wav=args.reference_wav, reference_speaker_name=args.reference_speaker_idx, emotion_name=args.emotion_idx)
# save the results
print(" > Saving output to {}".format(args.out_path))

View File

@ -423,3 +423,25 @@ class BaseTTS(BaseTrainerModel):
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.")
if hasattr(self, "emotion_manager") and self.emotion_manager is not None:
output_path = os.path.join(trainer.output_path, "emotions.json")
if hasattr(trainer.config, "model_args"):
if trainer.config.model_args.use_emotion_embedding and not trainer.config.model_args.external_emotions_embs_file:
self.emotion_manager.save_ids_to_file(output_path)
trainer.config.model_args.emotions_ids_file = output_path
else:
self.emotion_manager.save_embeddings_to_file(output_path)
trainer.config.model_args.external_emotions_embs_file = output_path
else:
if trainer.config.use_emotion_embedding and not trainer.config.external_emotions_embs_file:
self.emotion_manager.save_ids_to_file(output_path)
trainer.config.emotions_ids_file = output_path
else:
self.emotion_manager.save_embeddings_to_file(output_path)
trainer.config.external_emotions_embs_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `emotions.json` is saved to {output_path}.")
print(" > `emotions_ids_file` or `external_emotions_embs_file` is updated in the config.json.")

View File

@ -25,6 +25,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.emotions import EmotionManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
@ -278,7 +279,7 @@ class VitsDataset(TTSDataset):
"waveform_lens": wav_lens, # (B)
"waveform_rel_lens": wav_rel_lens,
"speaker_names": batch["speaker_name"],
"language_names": batch["language_name"],
"f": batch["language_name"],
"audio_files": batch["wav_file"],
"raw_text": batch["raw_text"],
}
@ -491,6 +492,15 @@ class VitsArgs(Coqpit):
speaker_embedding_channels: int = 256
use_d_vector_file: bool = False
d_vector_dim: int = 0
# use emotion embeddings
use_emotion_embedding: bool = False
use_external_emotions_embeddings: bool = False
emotions_ids_file: str = None
external_emotions_embs_file: str = None
emotion_embedding_dim: int = 0
num_emotions: int = 0
detach_dp_input: bool = True
use_language_embedding: bool = False
embedded_language_dim: int = 4
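A minimal sketch of how these new fields could be set for the two supported modes (values are illustrative, not defaults):

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
# Mode 1: learned emotion embedding layer driven by categorical emotion ids.
config.model_args.use_emotion_embedding = True
config.model_args.emotion_embedding_dim = 256
config.model_args.num_emotions = 4  # overridden by the EmotionManager when one is passed to the model

# Mode 2: external embeddings computed offline, e.g. by an emotion encoder.
# config.model_args.use_external_emotions_embeddings = True
# config.model_args.external_emotions_embs_file = "emotions.json"
# config.model_args.emotion_embedding_dim = 256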
@ -542,12 +552,13 @@ class Vits(BaseTTS):
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
emotion_manager: EmotionManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
self.init_multispeaker(config)
self.init_multilingual(config)
self.init_emotion(config, emotion_manager)
self.length_scale = self.args.length_scale
self.noise_scale = self.args.noise_scale
@ -576,7 +587,7 @@ class Vits(BaseTTS):
kernel_size=self.args.kernel_size_posterior_encoder,
dilation_rate=self.args.dilation_rate_posterior_encoder,
num_layers=self.args.num_layers_posterior_encoder,
cond_channels=self.embedded_speaker_dim,
cond_channels=self.cond_embedding_dim,
)
self.flow = ResidualCouplingBlocks(
@ -585,7 +596,7 @@ class Vits(BaseTTS):
kernel_size=self.args.kernel_size_flow,
dilation_rate=self.args.dilation_rate_flow,
num_layers=self.args.num_layers_flow,
cond_channels=self.embedded_speaker_dim,
cond_channels=self.cond_embedding_dim,
)
if self.args.use_sdp:
@ -595,7 +606,7 @@ class Vits(BaseTTS):
3,
self.args.dropout_p_duration_predictor,
4,
cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
cond_channels=self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0,
language_emb_dim=self.embedded_language_dim,
)
else:
@ -604,7 +615,7 @@ class Vits(BaseTTS):
256,
3,
self.args.dropout_p_duration_predictor,
cond_channels=self.embedded_speaker_dim,
cond_channels=self.cond_embedding_dim,
language_emb_dim=self.embedded_language_dim,
)
@ -618,7 +629,7 @@ class Vits(BaseTTS):
self.args.upsample_initial_channel_decoder,
self.args.upsample_rates_decoder,
inference_padding=0,
cond_channels=self.embedded_speaker_dim,
cond_channels=self.cond_embedding_dim,
conv_pre_weight_norm=False,
conv_post_weight_norm=False,
conv_post_bias=False,
@ -637,7 +648,7 @@ class Vits(BaseTTS):
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
self.embedded_speaker_dim = 0
self.cond_embedding_dim = 0
self.num_speakers = self.args.num_speakers
self.audio_transform = None
@ -680,14 +691,14 @@ class Vits(BaseTTS):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.cond_embedding_dim += self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.args.speaker_embedding_channels)
def _init_d_vector(self):
# pylint: disable=attribute-defined-outside-init
if hasattr(self, "emb_g"):
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
self.embedded_speaker_dim = self.args.d_vector_dim
self.cond_embedding_dim += self.args.d_vector_dim
def init_multilingual(self, config: Coqpit):
"""Initialize multilingual modules of a model.
@ -708,9 +719,36 @@ class Vits(BaseTTS):
self.embedded_language_dim = 0
self.emb_l = None
def init_emotion(self, config: Coqpit, emotion_manager: EmotionManager):
# pylint: disable=attribute-defined-outside-init
"""Initialize emotion modules of a model. A model can be trained either with a emotion embedding layer
or with external `embeddings` computed from a emotion encoder model.
You must provide a `emotion_manager` at initialization to set up the emotion modules.
Args:
config (Coqpit): Model configuration.
emotion_manager (EmotionManager): Emotion Manager.
"""
self.emotion_manager = emotion_manager
self.num_emotions = self.args.num_emotions
if self.emotion_manager:
self.num_emotions = self.emotion_manager.num_emotions
if self.args.use_emotion_embedding:
if self.num_emotions > 0:
print(" > initialization of emotion-embedding layers.")
self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
self.cond_embedding_dim += self.args.emotion_embedding_dim
if self.args.use_external_emotions_embeddings:
self.cond_embedding_dim += self.args.emotion_embedding_dim
def get_aux_input(self, aux_input: Dict):
sid, g, lid = self._set_cond_input(aux_input)
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid,
"emotion_embeddings": eg, "emotion_ids": eid}
def _freeze_layers(self):
if self.args.freeze_encoder:
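The change threaded through this file is that the fixed `embedded_speaker_dim` becomes an accumulated `cond_embedding_dim`, so the emotion embedding can be appended to the same conditioning vector. A rough sketch of the arithmetic, assuming both a speaker embedding layer and a learned emotion embedding of 256 channels each:

cond_embedding_dim = 0     # set in init_multispeaker()
cond_embedding_dim += 256  # speaker_embedding_channels, added in _init_speaker_embedding()
cond_embedding_dim += 256  # emotion_embedding_dim, added in init_emotion()
# The posterior encoder, flow, duration predictor and decoder all receive
# cond_channels=cond_embedding_dim, i.e. 512 in this example.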
@ -740,7 +778,7 @@ class Vits(BaseTTS):
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g, lid = None, None, None
sid, g, lid, eid, eg = None, None, None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"]
if sid.ndim == 0:
@ -755,7 +793,17 @@ class Vits(BaseTTS):
if lid.ndim == 0:
lid = lid.unsqueeze_(0)
return sid, g, lid
if "emotion_ids" in aux_input and aux_input["emotion_ids"] is not None:
eid = aux_input["emotion_ids"]
if eid.ndim == 0:
eid = eid.unsqueeze_(0)
if "emotion_embeddings" in aux_input and aux_input["emotion_embeddings"] is not None:
eg = F.normalize(aux_input["emotion_embeddings"]).unsqueeze(-1)
if eg.ndim == 2:
eg = eg.unsqueeze_(0)
return sid, g, lid, eid, eg
def _set_speaker_input(self, aux_input: Dict):
d_vectors = aux_input.get("d_vectors", None)
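A short illustration of what `_set_cond_input` now does with the emotion entries (the embedding dimension is an assumption):

import torch
import torch.nn.functional as F

emotion_embedding = torch.randn(1, 256)            # external embedding for one utterance
eg = F.normalize(emotion_embedding).unsqueeze(-1)  # unit norm, shape [1, 256, 1] to match the [B, C, 1] layout of g
eid = torch.tensor(2).unsqueeze(0)                 # a 0-dim emotion id becomes shape [1]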
@ -812,7 +860,7 @@ class Vits(BaseTTS):
y: torch.tensor,
y_lengths: torch.tensor,
waveform: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None},
) -> Dict:
"""Forward pass of the model.
@ -852,11 +900,19 @@ class Vits(BaseTTS):
- syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
"""
outputs = {}
sid, g, lid = self._set_cond_input(aux_input)
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
# speaker embedding
if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# emotion embedding
if self.args.use_emotion_embedding and eid is not None and eg is None:
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
# concat the emotion embedding and speaker embedding
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
# language embedding
lang_emb = None
if self.args.use_language_embedding and lid is not None:
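The concatenation above simply widens the conditioning tensor fed to the flow, posterior encoder and decoder. A toy shape check (channel sizes are assumptions):

import torch

g = torch.randn(2, 256, 1)     # speaker conditioning  [B, h1, 1]
eg = torch.randn(2, 256, 1)    # emotion conditioning  [B, h2, 1]
g = torch.cat([g, eg], dim=1)  # [B, h1 + h2, 1] -> torch.Size([2, 512, 1])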
@ -929,7 +985,7 @@ class Vits(BaseTTS):
return torch.tensor(x.shape[1:2]).to(x.device)
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None}
): # pylint: disable=dangerous-default-value
"""
Note:
@ -949,13 +1005,21 @@ class Vits(BaseTTS):
- m_p: :math:`[B, C, T_dec]`
- logs_p: :math:`[B, C, T_dec]`
"""
sid, g, lid = self._set_cond_input(aux_input)
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
x_lengths = self._set_x_lengths(x, aux_input)
# speaker embedding
if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1)
# emotion embedding
if self.args.use_emotion_embedding and eid is not None and eg is None:
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
# concat the emotion embedding and speaker embedding
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
# language embedding
lang_emb = None
if self.args.use_language_embedding and lid is not None:
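Putting the inference-side pieces together, a hedged usage sketch; it assumes `model` is a trained Vits instance with `use_emotion_embedding` enabled and that the text has already been tokenized:

import torch

token_ids = torch.LongTensor([[10, 23, 42, 7]])  # placeholder token ids from the tokenizer
aux_input = {
    "x_lengths": torch.LongTensor([token_ids.shape[1]]),
    "d_vectors": None,
    "speaker_ids": torch.LongTensor([0]),
    "language_ids": None,
    "emotion_ids": torch.LongTensor([1]),  # hypothetical emotion id
    "emotion_embeddings": None,
}
with torch.no_grad():
    outputs = model.inference(token_ids, aux_input=aux_input)
wav = outputs["model_outputs"]  # synthesized waveform tensor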
@ -1067,6 +1131,8 @@ class Vits(BaseTTS):
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
language_ids = batch["language_ids"]
emotion_embeddings = batch["emotion_embeddings"]
emotion_ids = batch["emotion_ids"]
waveform = batch["waveform"]
# generator pass
@ -1076,7 +1142,8 @@ class Vits(BaseTTS):
spec,
spec_lens,
waveform,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids,
"emotion_embeddings": emotion_embeddings, "emotion_ids": emotion_ids},
)
# cache tensors for the generator pass
@ -1196,7 +1263,7 @@ class Vits(BaseTTS):
config = self.config
# extract speaker and language info
text, speaker_name, style_wav, language_name = None, None, None, None
text, speaker_name, style_wav, language_name, emotion_name = None, None, None, None, None
if isinstance(sentence_info, list):
if len(sentence_info) == 1:
@ -1207,11 +1274,13 @@ class Vits(BaseTTS):
text, speaker_name, style_wav = sentence_info
elif len(sentence_info) == 4:
text, speaker_name, style_wav, language_name = sentence_info
elif len(sentence_info) == 5:
text, speaker_name, style_wav, language_name, emotion_name = sentence_info
else:
text = sentence_info
# get speaker id/d_vector
speaker_id, d_vector, language_id = None, None, None
speaker_id, d_vector, language_id, emotion_id, emotion_embedding = None, None, None, None, None
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
@ -1228,6 +1297,19 @@ class Vits(BaseTTS):
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.ids[language_name]
# get emotion id/embedding
if hasattr(self, "emotion_manager"):
if config.use_external_emotions_embeddings:
if emotion_name is None:
emotion_embedding = self.emotion_manager.get_random_embeddings()
else:
emotion_embedding = self.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
elif config.use_emotion_embedding:
if emotion_name is None:
emotion_id = self.emotion_manager.get_random_id()
else:
emotion_id = self.emotion_manager.ids[emotion_name]
return {
"text": text,
"speaker_id": speaker_id,
@ -1235,6 +1317,8 @@ class Vits(BaseTTS):
"d_vector": d_vector,
"language_id": language_id,
"language_name": language_name,
"emotion_embedding": emotion_embedding,
"emotion_ids": emotion_id
}
@torch.no_grad()
@ -1261,6 +1345,8 @@ class Vits(BaseTTS):
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
language_id=aux_inputs["language_id"],
emotion_embedding=aux_inputs["emotion_embedding"],
emotion_id=aux_inputs["emotion_ids"],
use_griffin_lim=True,
do_trim_silence=False,
).values()
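With this change, each entry in `config.test_sentences` can carry a fifth element naming the emotion (speaker and emotion names below are dataset-specific placeholders):

config.test_sentences = [
    # [text, speaker_name, style_wav, language_name, emotion_name]
    ["Be a voice, not an echo.", "ljspeech-1", None, None, "Happy"],
]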
@ -1275,10 +1361,12 @@ class Vits(BaseTTS):
logger.test_figures(steps, outputs["figures"])
def format_batch(self, batch: Dict) -> Dict:
"""Compute speaker, langugage IDs and d_vector for the batch if necessary."""
"""Compute speaker, langugage IDs, d_vector and emotion embeddings for the batch if necessary."""
speaker_ids = None
language_ids = None
d_vectors = None
emotion_embeddings = None
emotion_ids = None
# get numerical speaker ids from speaker names
if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding:
@ -1300,14 +1388,28 @@ class Vits(BaseTTS):
and self.language_manager.ids
and self.args.use_language_embedding
):
language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]
language_ids = [self.language_manager.ids[ln] for ln in batch["f"]]
if language_ids is not None:
language_ids = torch.LongTensor(language_ids)
# get emotion embedding
if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_external_emotions_embeddings:
emotion_mapping = self.emotion_manager.embeddings
emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
emotion_embeddings = torch.FloatTensor(emotion_embeddings)
if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_emotion_embedding:
emotion_mapping = self.emotion_manager.embeddings
emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]]
emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
emotion_ids = torch.LongTensor(emotion_ids)
batch["language_ids"] = language_ids
batch["d_vectors"] = d_vectors
batch["speaker_ids"] = speaker_ids
batch["emotion_embeddings"] = emotion_embeddings
batch["emotion_ids"] = emotion_ids
return batch
def format_batch_on_device(self, batch):
@ -1480,12 +1582,13 @@ class Vits(BaseTTS):
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
language_manager = LanguageManager.init_from_config(config)
emotion_manager = EmotionManager.init_from_config(config)
if config.model_args.speaker_encoder_model_path:
speaker_manager.init_encoder(
config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)
##################################
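The managers can also be wired up by hand with the constructor signature added above; a sketch assuming `config`, `ap`, `tokenizer`, `speaker_manager` and `language_manager` are already built as in the factory method:

from TTS.tts.models.vits import Vits
from TTS.tts.utils.emotions import EmotionManager

emotion_manager = EmotionManager.init_from_config(config)  # picks the ids file or the external embeddings file from the config
model = Vits(config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)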

View File

@ -1,10 +1,8 @@
import json
import os
from typing import Any, Dict, List, Tuple, Union
from typing import Any, List
import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default
@ -34,7 +32,7 @@ class EmotionManager(EmbeddingManager):
computes the embeddings for a given clip or emotion.
Args:
emo_embeddings_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
embeddings_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
emotion_id_file_path (str, optional): Path to the metafile that maps emotion names to ids used by
TTS models. Defaults to "".
encoder_model_path (str, optional): Path to the emotion encoder model file. Defaults to "".
@ -50,14 +48,14 @@ class EmotionManager(EmbeddingManager):
def __init__(
self,
emo_embeddings_file_path: str = "",
embeddings_file_path: str = "",
emotion_id_file_path: str = "",
encoder_model_path: str = "",
encoder_config_path: str = "",
use_cuda: bool = False,
):
super().__init__(
external_emotions_ids_file_path=emo_embeddings_file_path,
embedding_file_path=embeddings_file_path,
id_file_path=emotion_id_file_path,
encoder_model_path=encoder_model_path,
encoder_config_path=encoder_config_path,
@ -98,11 +96,15 @@ class EmotionManager(EmbeddingManager):
emotion_manager = EmotionManager(
emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
)
if get_from_config_or_model_args_with_default(config, "use_external_emotion_embedding", False):
if get_from_config_or_model_args_with_default(config, "external_emotions_ids_file", None):
elif get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
emotion_manager = EmotionManager(
embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_ids_file", None)
embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
)
if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False):
if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
emotion_manager = EmotionManager(
embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
)
return emotion_manager
@ -157,26 +159,26 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
if c.use_external_emotions_embeddings:
# restore emotion manager with the embedding file
if not os.path.exists(emotions_ids_file):
print("WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_ids_file")
if not os.path.exists(c.external_emotions_ids_file):
print("WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_embs_file")
if not os.path.exists(c.external_emotions_embs_file):
raise RuntimeError(
"You must copy the file emotions.json to restore_path, or set a valid file in CONFIG.external_emotions_ids_file"
"You must copy the file emotions.json to restore_path, or set a valid file in CONFIG.external_emotions_embs_file"
)
emotion_manager.load_embeddings_from_file(c.external_emotions_ids_file)
emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)
emotion_manager.load_embeddings_from_file(emotions_ids_file)
elif not c.use_external_emotions_embeddings: # restore emotion manager with emotion ID file.
emotion_manager.load_ids_from_file(emotions_ids_file)
elif c.use_external_emotions_embeddings and c.external_emotions_ids_file:
elif c.use_external_emotions_embeddings and c.external_emotions_embs_file:
# new emotion manager with external emotion embeddings.
emotion_manager.load_embeddings_from_file(c.external_emotions_ids_file)
elif c.use_external_emotions_embeddings and not c.external_emotions_ids_file:
emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)
elif c.use_external_emotions_embeddings and not c.external_emotions_embs_file:
raise "use_external_emotions_embeddings is True, so you need pass a external emotion embedding file."
elif c.use_emotion_embedding:
if "emotions_ids_file" in c and c.emotions_ids_file:
emotion_manager.load_ids_from_file(c.emotions_ids_file)
else: # get the ids from the external embeddings file
emotion_manager.load_embeddings_from_file(c.external_emotions_ids_file)
emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)
if emotion_manager.num_emotions > 0:
print(
@ -189,9 +191,8 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
if out_path:
out_file_path = os.path.join(out_path, "emotions.json")
print(f" > Saving `emotions.json` to {out_file_path}.")
if c.use_external_emotions_embeddings and c.external_emotions_ids_file:
if c.use_external_emotions_embeddings and c.external_emotions_embs_file:
emotion_manager.save_embeddings_to_file(out_file_path)
else:
emotion_manager.save_ids_to_file(out_file_path)
return emotion_manager
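A usage sketch for the manager itself, assuming it loads the given files on construction as the SpeakerManager does (file paths and emotion names are hypothetical):

from TTS.tts.utils.emotions import EmotionManager

# Categorical mode: a name-to-id map.
manager = EmotionManager(emotion_id_file_path="emotion_ids.json")
emotion_id = manager.ids["Happy"]

# External-embedding mode: per-utterance embeddings computed by an emotion encoder.
manager = EmotionManager(embeddings_file_path="emotions.json")
embedding = manager.get_mean_embedding("Happy", num_samples=None, randomize=False)
print(manager.num_emotions)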

View File

@ -28,6 +28,8 @@ def run_model_torch(
style_mel: torch.Tensor = None,
d_vector: torch.Tensor = None,
language_id: torch.Tensor = None,
emotion_id: torch.Tensor = None,
emotion_embedding: torch.Tensor = None,
) -> Dict:
"""Run a torch model for inference. It does not support batch inference.
@ -54,6 +56,8 @@ def run_model_torch(
"d_vectors": d_vector,
"style_mel": style_mel,
"language_ids": language_id,
"emotion_ids": emotion_id,
"emotion_embeddings": emotion_embedding,
},
)
return outputs
@ -119,6 +123,8 @@ def synthesis(
do_trim_silence=False,
d_vector=None,
language_id=None,
emotion_id=None,
emotion_embedding=None
):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
the vocoder model.
@ -176,12 +182,18 @@ def synthesis(
if language_id is not None:
language_id = id_to_torch(language_id, cuda=use_cuda)
if emotion_id is not None:
emotion_id = id_to_torch(emotion_id, cuda=use_cuda)
if emotion_embedding is not None:
emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda)
if not isinstance(style_mel, dict):
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
text_inputs = text_inputs.unsqueeze(0)
# synthesize voice
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id, emotion_id=emotion_id, emotion_embedding=emotion_embedding)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
alignments = outputs["alignments"]
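A hedged example of driving `synthesis` with the new arguments directly; the positional config/use_cuda arguments are assumptions, while the emotion keywords are the ones added here:

from TTS.tts.utils.synthesis import synthesis

outputs = synthesis(
    model,
    "Hello world.",
    model.config,
    use_cuda=False,
    speaker_id=0,
    emotion_id=1,  # or emotion_embedding=... when using external embeddings
    use_griffin_lim=True,
    do_trim_silence=False,
)
wav = outputs["wav"]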

View File

@ -22,6 +22,7 @@ class Synthesizer(object):
tts_checkpoint: str,
tts_config_path: str,
tts_speakers_file: str = "",
tts_emotions_file: str = "",
tts_languages_file: str = "",
vocoder_checkpoint: str = "",
vocoder_config: str = "",
@ -52,6 +53,7 @@ class Synthesizer(object):
self.tts_checkpoint = tts_checkpoint
self.tts_config_path = tts_config_path
self.tts_speakers_file = tts_speakers_file
self.tts_emotions_file = tts_emotions_file
self.tts_languages_file = tts_languages_file
self.vocoder_checkpoint = vocoder_checkpoint
self.vocoder_config = vocoder_config
@ -180,6 +182,7 @@ class Synthesizer(object):
style_wav=None,
reference_wav=None,
reference_speaker_name=None,
emotion_name=None,
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
@ -234,7 +237,7 @@ class Synthesizer(object):
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
)
# handle multi-lingaul
# handle multi-lingual
language_id = None
if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
@ -254,6 +257,29 @@ class Synthesizer(object):
"Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
)
# handle emotion
emotion_embedding, emotion_id = None, None
if self.tts_emotions_file or (hasattr(self.tts_model, "emotion_manager") and self.tts_model.emotion_manager is not None):
if emotion_name and isinstance(emotion_name, str):
if getattr(self.tts_config, "use_external_emotions_embeddings", False) or getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False):
# get the average emotion embedding from the saved embeddings.
emotion_embedding = self.tts_model.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
emotion_embedding = np.array(emotion_embedding)[None, :] # [1 x embedding_dim]
else:
# get emotion idx from the emotion name
emotion_id = self.tts_model.emotion_manager.ids[emotion_name]
elif not emotion_name:
raise ValueError(
" [!] Look like you use an emotion model. "
"You need to define either a `emotion_name` to use an emotion model."
)
else:
if emotion_name:
raise ValueError(
f" [!] Missing emotion.json file path for selecting the emotion {emotion_name}."
"Define path for emotion.json if it is an emotion model or remove defined emotion idx. "
)
# compute a new d_vector from the given clip.
if speaker_wav is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
@ -273,6 +299,8 @@ class Synthesizer(object):
style_wav=style_wav,
use_griffin_lim=use_gl,
d_vector=speaker_embedding,
emotion_embedding=emotion_embedding,
emotion_id=emotion_id,
)
waveform = outputs["wav"]
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
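In short, `emotion_name` resolves in one of two ways depending on the config; a sketch of the two lookups, assuming `synthesizer` is the Synthesizer built earlier and the emotion name is hypothetical:

# External embeddings enabled: average the stored embeddings for that emotion.
emotion_embedding = synthesizer.tts_model.emotion_manager.get_mean_embedding("Happy", num_samples=None, randomize=False)
# Learned embedding layer: map the name to the integer id consumed by Vits.emb_emotion.
emotion_id = synthesizer.tts_model.emotion_manager.ids["Happy"]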

View File

@ -0,0 +1,83 @@
import glob
import os
import shutil
from trainer import get_last_checkpoint
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"],
],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60
# enable multi-speaker training with a speaker embedding layer (d-vectors disabled)
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = None
config.model_args.d_vector_dim = 256
# emotion
config.model_args.use_external_emotions_embeddings = True
config.model_args.use_emotion_embedding = False
config.model_args.emotion_embedding_dim = 256
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech_test "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
emotion_id = "ljspeech-1"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
continue_emotion_path = os.path.join(continue_path, "speakers.json")
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)