mirror of https://github.com/coqui-ai/TTS.git
Fix Style tests
commit 01dd4e4051
parent 46762ccf35
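This commit contains only mechanical formatting fixes: imports are regrouped and alphabetized, long conditions and dict literals are wrapped one item per line with trailing commas, string quotes are normalized to double quotes, and calls are collapsed or expanded to fit the line-length limit. As a minimal sketch (not part of the commit), assuming black and isort with a 120-character line length (inferred from the wrapped lines below; the repo's exact configuration may differ), the same rewrites can be reproduced programmatically:

# Illustrative only: reproduce the kind of style fixes in this commit with
# black + isort. Assumes `pip install black isort`; the 120-char limit is
# inferred from the wrapped lines in these hunks.
import black
import isort

SRC = '''
import torch
import argparse


def get_aux_input(sid, g, lid, eid, eg):
    return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid,
            "emotion_embeddings": eg, "emotion_ids": eid}
'''

# isort regroups and alphabetizes the imports; black rewraps the over-long dict
# literal one key per line with a trailing comma, matching the get_aux_input
# hunk further down.
formatted = black.format_str(isort.code(SRC), mode=black.Mode(line_length=120))
print(formatted)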
@@ -1,8 +1,8 @@
 import argparse
 import os
-import torch
 from argparse import RawTextHelpFormatter
 
+import torch
 from tqdm import tqdm
 
 from TTS.config import load_config
@@ -30,8 +30,10 @@ parser.add_argument(
     help="Path to dataset config file.",
 )
 parser.add_argument("output_path", type=str, help="path for output .json file.")
+parser.add_argument("--old_file", type=str, help="Previous .json file, only compute for new audios.", default=None)
+parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
 parser.add_argument(
-    "--old_file", type=str, help="Previous .json file, only compute for new audios.", default=None
+    "--use_predicted_label", type=bool, help="If True and predicted label is available with will use it.", default=False
 )
 parser.add_argument("--use_cuda", type=bool, help="flag to set cuda. Default False", default=False)
 parser.add_argument("--no_eval", type=bool, help="Do not compute eval?. Default False", default=False)
@@ -75,7 +77,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
     embedd = encoder_manager.compute_embedding_from_clip(wav_file)
 
     if args.use_predicted_label:
-        map_classid_to_classname = getattr(encoder_manager.encoder_config, 'map_classid_to_classname', None)
+        map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
         if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
             embedding = torch.FloatTensor(embedd).unsqueeze(0)
             if encoder_manager.use_cuda:
@@ -84,9 +86,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
             class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
             class_name = map_classid_to_classname[str(class_id)]
         else:
-            raise RuntimeError(
-                " [!] use_predicted_label is enable and predicted_labels is not available !!"
-            )
+            raise RuntimeError(" [!] use_predicted_label is enable and predicted_labels is not available !!")
 
     # create class_mapping if target dataset is defined
     class_mapping[wav_file_name] = {}

@@ -13,7 +13,6 @@ def compute_encoder_accuracy(dataset_items, encoder_manager):
-
     class_name_key = encoder_manager.encoder_config.class_name_key
     map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
 
    class_acc_dict = {}
 
     # compute embeddings for all wav_files

@@ -433,7 +433,10 @@ class BaseTTS(BaseTrainerModel):
             output_path = os.path.join(trainer.output_path, "emotions.json")
 
             if hasattr(trainer.config, "model_args"):
-                if trainer.config.model_args.use_emotion_embedding and not trainer.config.model_args.external_emotions_embs_file:
+                if (
+                    trainer.config.model_args.use_emotion_embedding
+                    and not trainer.config.model_args.external_emotions_embs_file
+                ):
                     self.emotion_manager.save_ids_to_file(output_path)
                     trainer.config.model_args.emotions_ids_file = output_path
             else:

@@ -22,10 +22,10 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -712,8 +712,8 @@ class Vits(BaseTTS):
     def init_consistency_loss(self):
         if self.args.use_speaker_encoder_as_loss and self.args.use_emotion_encoder_as_loss:
             raise RuntimeError(
                 " [!] The use of speaker consistency loss (SCL) and emotion consistency loss (ECL) together is not supported, please disable one of those !!"
             )
 
         if self.args.use_speaker_encoder_as_loss:
             if self.speaker_manager.encoder is None and (
@@ -849,8 +849,14 @@ class Vits(BaseTTS):
 
     def get_aux_input(self, aux_input: Dict):
         sid, g, lid, eid, eg = self._set_cond_input(aux_input)
-        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid,
-                "emotion_embeddings": eg, "emotion_ids": eid}
+        return {
+            "speaker_ids": sid,
+            "style_wav": None,
+            "d_vectors": g,
+            "language_ids": lid,
+            "emotion_embeddings": eg,
+            "emotion_ids": eid,
+        }
 
     def _freeze_layers(self):
         if self.args.freeze_encoder:
@@ -979,7 +985,13 @@ class Vits(BaseTTS):
         y: torch.tensor,
         y_lengths: torch.tensor,
         waveform: torch.tensor,
-        aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None},
+        aux_input={
+            "d_vectors": None,
+            "speaker_ids": None,
+            "language_ids": None,
+            "emotion_embeddings": None,
+            "emotion_ids": None,
+        },
     ) -> Dict:
         """Forward pass of the model.
 
@@ -1033,7 +1045,7 @@ class Vits(BaseTTS):
             if g is None:
                 g = eg
             else:
-                g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
+                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
 
         # language embedding
         lang_emb = None
@@ -1071,7 +1083,9 @@ class Vits(BaseTTS):
             )
 
         if self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss:
-            encoder = self.speaker_manager.encoder if self.args.use_speaker_encoder_as_loss else self.emotion_manager.encoder
+            encoder = (
+                self.speaker_manager.encoder if self.args.use_speaker_encoder_as_loss else self.emotion_manager.encoder
+            )
             # concate generated and GT waveforms
             wavs_batch = torch.cat((wav_seg, o), dim=0)
 
@@ -1116,7 +1130,16 @@ class Vits(BaseTTS):
 
     @torch.no_grad()
     def inference(
-        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None}
+        self,
+        x,
+        aux_input={
+            "x_lengths": None,
+            "d_vectors": None,
+            "speaker_ids": None,
+            "language_ids": None,
+            "emotion_embeddings": None,
+            "emotion_ids": None,
+        },
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
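The `# pylint: disable=dangerous-default-value` comment kept in the hunk above exists because `aux_input` uses a mutable dict as its default. A short background sketch (mine, not from the commit): Python evaluates default values once, at definition time, so a mutable default is shared across calls.

# Illustrative only: the pitfall pylint's dangerous-default-value warns about.
def risky(item, bucket=[]):  # pylint: disable=dangerous-default-value
    bucket.append(item)
    return bucket

print(risky(1))  # [1]
print(risky(2))  # [1, 2] -- the same list object is reused across calls

# The common fix is a None sentinel, creating a fresh object per call:
def safe(item, bucket=None):
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

print(safe(1))  # [1]
print(safe(2))  # [2]

Vits presumably keeps the dict default because it is only read, never mutated, so silencing the warning is harmless there.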
@@ -1152,7 +1175,7 @@ class Vits(BaseTTS):
             if g is None:
                 g = eg
             else:
-                g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
+                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
 
         # language embedding
         lang_emb = None
@@ -1297,8 +1320,13 @@ class Vits(BaseTTS):
                 spec,
                 spec_lens,
                 waveform,
-                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids,
-                           "emotion_embeddings": emotion_embeddings, "emotion_ids": emotion_ids},
+                aux_input={
+                    "d_vectors": d_vectors,
+                    "speaker_ids": speaker_ids,
+                    "language_ids": language_ids,
+                    "emotion_embeddings": emotion_embeddings,
+                    "emotion_ids": emotion_ids,
+                },
             )
 
             # cache tensors for the generator pass
@@ -1362,7 +1390,8 @@ class Vits(BaseTTS):
                     feats_disc_fake=feats_disc_fake,
                     feats_disc_real=feats_disc_real,
                     loss_duration=self.model_outputs_cache["loss_duration"],
-                    use_encoder_consistency_loss=self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss,
+                    use_encoder_consistency_loss=self.args.use_speaker_encoder_as_loss
+                    or self.args.use_emotion_encoder_as_loss,
                     gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                     syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
                 )
@@ -1464,7 +1493,9 @@ class Vits(BaseTTS):
             if emotion_name is None:
                 emotion_embedding = self.emotion_manager.get_random_embeddings()
             else:
-                emotion_embedding = self.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
+                emotion_embedding = self.emotion_manager.get_mean_embedding(
+                    emotion_name, num_samples=None, randomize=False
+                )
         elif config.use_emotion_embedding:
             if emotion_name is None:
                 emotion_id = self.emotion_manager.get_random_id()
@@ -1479,7 +1510,7 @@ class Vits(BaseTTS):
             "language_id": language_id,
             "language_name": language_name,
             "emotion_embedding": emotion_embedding,
-            "emotion_ids": emotion_id
+            "emotion_ids": emotion_id,
         }
 
     @torch.no_grad()
@@ -1551,7 +1582,11 @@ class Vits(BaseTTS):
         language_ids = torch.LongTensor(language_ids)
 
         # get emotion embedding
-        if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_external_emotions_embeddings:
+        if (
+            self.emotion_manager is not None
+            and self.emotion_manager.embeddings
+            and self.args.use_external_emotions_embeddings
+        ):
             emotion_mapping = self.emotion_manager.embeddings
             emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
             emotion_embeddings = torch.FloatTensor(emotion_embeddings)
@@ -1783,13 +1818,9 @@ class Vits(BaseTTS):
         emotion_manager = EmotionManager.init_from_config(config)
 
         if config.model_args.encoder_model_path and speaker_manager is not None:
-            speaker_manager.init_encoder(
-                config.model_args.encoder_model_path, config.model_args.encoder_config_path
-            )
+            speaker_manager.init_encoder(config.model_args.encoder_model_path, config.model_args.encoder_config_path)
         elif config.model_args.encoder_model_path and emotion_manager is not None:
-            emotion_manager.init_encoder(
-                config.model_args.encoder_model_path, config.model_args.encoder_config_path
-            )
+            emotion_manager.init_encoder(config.model_args.encoder_model_path, config.model_args.encoder_config_path)
 
         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)
 

@@ -8,6 +8,7 @@ from coqpit import Coqpit
 from TTS.config import get_from_config_or_model_args_with_default
 from TTS.tts.utils.managers import EmbeddingManager
 
+
 class EmotionManager(EmbeddingManager):
     """Manage the emotions for emotional TTS. Load a datafile and parse the information
     in a way that can be queried by emotion or clip.
@@ -59,8 +60,8 @@ class EmotionManager(EmbeddingManager):
             id_file_path=emotion_id_file_path,
             encoder_model_path=encoder_model_path,
             encoder_config_path=encoder_config_path,
-            use_cuda=use_cuda
+            use_cuda=use_cuda,
         )
 
     @property
     def num_emotions(self):
@@ -98,13 +99,17 @@ class EmotionManager(EmbeddingManager):
                 )
         elif get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
             emotion_manager = EmotionManager(
-                embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
+                embeddings_file_path=get_from_config_or_model_args_with_default(
+                    config, "external_emotions_embs_file", None
+                )
             )
 
         if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False):
             if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
                 emotion_manager = EmotionManager(
-                    embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
+                    embeddings_file_path=get_from_config_or_model_args_with_default(
+                        config, "external_emotions_embs_file", None
+                    )
                 )
 
         return emotion_manager
@@ -159,7 +164,9 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
     if c.use_external_emotions_embeddings:
         # restore emotion manager with the embedding file
         if not os.path.exists(emotions_ids_file):
-            print("WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_embs_file")
+            print(
+                "WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_embs_file"
+            )
             if not os.path.exists(c.external_emotions_embs_file):
                 raise RuntimeError(
                     "You must copy the file emotions.json to restore_path, or set a valid file in CONFIG.external_emotions_embs_file"
@@ -177,7 +184,7 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
     elif c.use_emotion_embedding:
         if "emotions_ids_file" in c and c.emotions_ids_file:
             emotion_manager.load_ids_from_file(c.emotions_ids_file)
-        else: # enable get ids from eternal embedding files
+        else:  # enable get ids from eternal embedding files
             emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)
 
    if emotion_manager.num_emotions > 0:

@@ -127,7 +127,7 @@ def synthesis(
     d_vector=None,
     language_id=None,
     emotion_id=None,
-    emotion_embedding=None
+    emotion_embedding=None,
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
     the vocoder model.

@@ -123,26 +123,42 @@ class Synthesizer(object):
         if use_cuda:
             self.tts_model.cuda()
 
-        if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager") and self.tts_model.speaker_manager is not None:
+        if (
+            self.encoder_checkpoint
+            and hasattr(self.tts_model, "speaker_manager")
+            and self.tts_model.speaker_manager is not None
+        ):
             self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config, use_cuda)
 
-        if self.tts_emotions_file and hasattr(self.tts_model, "emotion_manager") and self.tts_model.emotion_manager is not None:
-            if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (getattr(self.tts_config, "model_args", None) and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)):
+        if (
+            self.tts_emotions_file
+            and hasattr(self.tts_model, "emotion_manager")
+            and self.tts_model.emotion_manager is not None
+        ):
+            if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
+                getattr(self.tts_config, "model_args", None)
+                and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)
+            ):
                 self.tts_model.emotion_manager.load_embeddings_from_file(self.tts_emotions_file)
             else:
                 self.tts_model.emotion_manager.load_ids_from_file(self.tts_emotions_file)
 
-        if self.tts_speakers_file and hasattr(self.tts_model, "speaker_manager") and self.tts_model.speaker_manager is not None:
-            if getattr(self.tts_config, "use_d_vector_file", False) or (getattr(self.tts_config, "model_args", None) and getattr(self.tts_config.model_args, "use_d_vector_file", False)):
+        if (
+            self.tts_speakers_file
+            and hasattr(self.tts_model, "speaker_manager")
+            and self.tts_model.speaker_manager is not None
+        ):
+            if getattr(self.tts_config, "use_d_vector_file", False) or (
+                getattr(self.tts_config, "model_args", None)
+                and getattr(self.tts_config.model_args, "use_d_vector_file", False)
+            ):
                 self.tts_model.speaker_manager.load_embeddings_from_file(self.tts_speakers_file)
             else:
                 self.tts_model.speaker_manager.load_ids_from_file(self.tts_speakers_file)
 
     def _set_speaker_encoder_paths_from_tts_config(self):
         """Set the encoder paths from the tts model config for models with speaker encoders."""
-        if hasattr(self.tts_config, "model_args") and hasattr(
-            self.tts_config.model_args, "encoder_config_path"
-        ):
+        if hasattr(self.tts_config, "model_args") and hasattr(self.tts_config.model_args, "encoder_config_path"):
             self.encoder_checkpoint = self.tts_config.model_args.encoder_model_path
             self.encoder_config = self.tts_config.model_args.encoder_config_path
 
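The conditions reformatted above all follow one defensive pattern: a flag may live at the top level of the config or nested under `model_args`, so both locations are probed with `getattr` defaults. A small sketch of the same lookup (class names below are stand-ins, not from the repo):

# Illustrative only: why Synthesizer probes both cfg.<flag> and cfg.model_args.<flag>.
class ModelArgs:
    use_d_vector_file = True


class NestedConfig:  # composite configs nest flags under model_args
    model_args = ModelArgs()


class FlatConfig:  # single-model configs expose flags at the top level
    use_d_vector_file = True


def wants_d_vectors(cfg):
    # Mirrors the reformatted condition: flat attribute first, then the
    # nested model_args fallback, defaulting to False when neither exists.
    return bool(
        getattr(cfg, "use_d_vector_file", False)
        or (getattr(cfg, "model_args", None) and getattr(cfg.model_args, "use_d_vector_file", False))
    )


print(wants_d_vectors(FlatConfig()))    # True
print(wants_d_vectors(NestedConfig()))  # True
print(wants_d_vectors(object()))        # False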
@@ -277,11 +293,18 @@ class Synthesizer(object):
 
         # handle emotion
         emotion_embedding, emotion_id = None, None
-        if self.tts_emotions_file or (getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)):
+        if self.tts_emotions_file or (
+            getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
+        ):
             if emotion_name and isinstance(emotion_name, str):
-                if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (getattr(self.tts_config, "model_args", None) and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)):
+                if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
+                    getattr(self.tts_config, "model_args", None)
+                    and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)
+                ):
                     # get the average speaker embedding from the saved embeddings.
-                    emotion_embedding = self.tts_model.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
+                    emotion_embedding = self.tts_model.emotion_manager.get_mean_embedding(
+                        emotion_name, num_samples=None, randomize=False
+                    )
                     emotion_embedding = np.array(emotion_embedding)[None, :]  # [1 x embedding_dim]
                 else:
                     # get speaker idx from the speaker name
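Fixes like these are typically verified locally before committing by running the formatters in check mode; a hedged sketch (the repo drives its checks through Makefile targets, and the exact flags and paths here are assumptions):

# Illustrative only: run black and isort in check mode from Python.
import subprocess

for cmd in (
    ["black", "--check", "--line-length", "120", "TTS/"],
    ["isort", "--check-only", "TTS/"],
):
    result = subprocess.run(cmd, capture_output=True, text=True)
    print(" ".join(cmd), "->", "ok" if result.returncode == 0 else "needs formatting")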