Fix Style tests

Edresson Casanova 2022-03-30 16:51:39 -03:00
parent aebbdfc62b
commit 047cebd7b8
10 changed files with 128 additions and 60 deletions

View File

@@ -1,8 +1,8 @@
import argparse
import os
import torch
from argparse import RawTextHelpFormatter
import torch
from tqdm import tqdm
from TTS.config import load_config
@@ -30,11 +30,11 @@ parser.add_argument(
help="Path to dataset config file.",
)
parser.add_argument("output_path", type=str, help="path for output .json file.")
parser.add_argument(
"--old_file", type=str, help="Previous .json file, only compute for new audios.", default=None
)
parser.add_argument("--old_file", type=str, help="Previous .json file, only compute for new audios.", default=None)
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
parser.add_argument("--use_predicted_label", type=bool, help="If True and predicted label is available with will use it.", default=False)
parser.add_argument(
"--use_predicted_label", type=bool, help="If True and predicted label is available with will use it.", default=False
)
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
args = parser.parse_args()
@@ -71,7 +71,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
embedd = encoder_manager.compute_embedding_from_clip(wav_file)
if args.use_predicted_label:
map_classid_to_classname = getattr(encoder_manager.encoder_config, 'map_classid_to_classname', None)
map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
embedding = torch.FloatTensor(embedd).unsqueeze(0)
if encoder_manager.use_cuda:
@@ -80,9 +80,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
class_name = map_classid_to_classname[str(class_id)]
else:
raise RuntimeError(
" [!] use_predicted_label is enable and predicted_labels is not available !!"
)
raise RuntimeError(" [!] use_predicted_label is enable and predicted_labels is not available !!")
# create class_mapping if target dataset is defined
class_mapping[wav_file_name] = {}
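Read end to end, the predicted-label branch above boils down to the following self-contained sketch. This is a hypothetical usage example, not part of the commit: the checkpoint, config, and wav paths are placeholders, and the SpeakerManager construction mirrors the keyword arguments that appear elsewhere in this commit.

import torch
from TTS.tts.utils.speakers import SpeakerManager

# Assumed setup: an encoder checkpoint and its config (placeholder paths).
encoder_manager = SpeakerManager(
    encoder_model_path="encoder_model.pth",
    encoder_config_path="encoder_config.json",
    use_cuda=False,
)

# Compute an embedding for one clip, then map it to a class name when the
# encoder was trained with a classification head (same calls as in the script above).
embedd = encoder_manager.compute_embedding_from_clip("sample.wav")
map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
    embedding = torch.FloatTensor(embedd).unsqueeze(0)
    if encoder_manager.use_cuda:
        embedding = embedding.cuda()
    class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
    class_name = map_classid_to_classname[str(class_id)]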

View File

@@ -12,7 +12,7 @@ from TTS.tts.utils.speakers import SpeakerManager
def compute_encoder_accuracy(dataset_items, encoder_manager):
class_name_key = encoder_manager.encoder_config.class_name_key
map_classid_to_classname = getattr(encoder_manager.encoder_config, 'map_classid_to_classname', None)
map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
class_acc_dict = {}
# compute embeddings for all wav_files

View File

@@ -319,7 +319,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
args.speaker_wav,
reference_wav=args.reference_wav,
reference_speaker_name=args.reference_speaker_idx,
emotion_name=args.emotion_idx
emotion_name=args.emotion_idx,
)
# save the results

View File

@@ -422,9 +422,12 @@ class BaseTTS(BaseTrainerModel):
if hasattr(self, "emotion_manager") and self.emotion_manager is not None:
output_path = os.path.join(trainer.output_path, "emotions.json")
if hasattr(trainer.config, "model_args"):
if trainer.config.model_args.use_emotion_embedding and not trainer.config.model_args.external_emotions_embs_file:
if (
trainer.config.model_args.use_emotion_embedding
and not trainer.config.model_args.external_emotions_embs_file
):
self.emotion_manager.save_ids_to_file(output_path)
trainer.config.model_args.emotions_ids_file = output_path
else:
@@ -440,4 +443,4 @@ class BaseTTS(BaseTrainerModel):
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `emotions.json` is saved to {output_path}.")
print(" > `emotions_ids_file` or `external_emotions_embs_file` is updated in the config.json.")
print(" > `emotions_ids_file` or `external_emotions_embs_file` is updated in the config.json.")

View File

@@ -22,10 +22,10 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.emotions import EmotionManager
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.emotions import EmotionManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -666,8 +666,8 @@ class Vits(BaseTTS):
def init_consistency_loss(self):
if self.args.use_speaker_encoder_as_loss and self.args.use_emotion_encoder_as_loss:
raise RuntimeError(
" [!] The use of speaker consistency loss (SCL) and emotion consistency loss (ECL) together is not supported, please disable one of those !!"
)
" [!] The use of speaker consistency loss (SCL) and emotion consistency loss (ECL) together is not supported, please disable one of those !!"
)
if self.args.use_speaker_encoder_as_loss:
if self.speaker_manager.encoder is None and (
@@ -773,8 +773,14 @@ class Vits(BaseTTS):
def get_aux_input(self, aux_input: Dict):
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid,
"emotion_embeddings": eg, "emotion_ids": eid}
return {
"speaker_ids": sid,
"style_wav": None,
"d_vectors": g,
"language_ids": lid,
"emotion_embeddings": eg,
"emotion_ids": eid,
}
def _freeze_layers(self):
if self.args.freeze_encoder:
@@ -886,7 +892,13 @@ class Vits(BaseTTS):
y: torch.tensor,
y_lengths: torch.tensor,
waveform: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None},
aux_input={
"d_vectors": None,
"speaker_ids": None,
"language_ids": None,
"emotion_embeddings": None,
"emotion_ids": None,
},
) -> Dict:
"""Forward pass of the model.
@@ -940,7 +952,7 @@ class Vits(BaseTTS):
if g is None:
g = eg
else:
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
# language embedding
lang_emb = None
@@ -974,7 +986,9 @@ class Vits(BaseTTS):
)
if self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss:
encoder = self.speaker_manager.encoder if self.args.use_speaker_encoder_as_loss else self.emotion_manager.encoder
encoder = (
self.speaker_manager.encoder if self.args.use_speaker_encoder_as_loss else self.emotion_manager.encoder
)
# concate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0)
@@ -1018,7 +1032,16 @@ class Vits(BaseTTS):
return torch.tensor(x.shape[1:2]).to(x.device)
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None}
self,
x,
aux_input={
"x_lengths": None,
"d_vectors": None,
"speaker_ids": None,
"language_ids": None,
"emotion_embeddings": None,
"emotion_ids": None,
},
): # pylint: disable=dangerous-default-value
"""
Note:
@@ -1054,7 +1077,7 @@ class Vits(BaseTTS):
if g is None:
g = eg
else:
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
# language embedding
lang_emb = None
@@ -1187,8 +1210,13 @@ class Vits(BaseTTS):
spec,
spec_lens,
waveform,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids,
"emotion_embeddings": emotion_embeddings, "emotion_ids": emotion_ids},
aux_input={
"d_vectors": d_vectors,
"speaker_ids": speaker_ids,
"language_ids": language_ids,
"emotion_embeddings": emotion_embeddings,
"emotion_ids": emotion_ids,
},
)
# cache tensors for the generator pass
@@ -1246,7 +1274,8 @@ class Vits(BaseTTS):
feats_disc_fake=feats_disc_fake,
feats_disc_real=feats_disc_real,
loss_duration=self.model_outputs_cache["loss_duration"],
use_encoder_consistency_loss=self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss,
use_encoder_consistency_loss=self.args.use_speaker_encoder_as_loss
or self.args.use_emotion_encoder_as_loss,
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
)
@@ -1348,14 +1377,15 @@ class Vits(BaseTTS):
if emotion_name is None:
emotion_embedding = self.emotion_manager.get_random_embeddings()
else:
emotion_embedding = self.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
emotion_embedding = self.emotion_manager.get_mean_embedding(
emotion_name, num_samples=None, randomize=False
)
elif config.use_emotion_embedding:
if emotion_name is None:
emotion_id = self.emotion_manager.get_random_id()
else:
emotion_id = self.emotion_manager.ids[emotion_name]
return {
"text": text,
"speaker_id": speaker_id,
@@ -1364,7 +1394,7 @@ class Vits(BaseTTS):
"language_id": language_id,
"language_name": language_name,
"emotion_embedding": emotion_embedding,
"emotion_ids": emotion_id
"emotion_ids": emotion_id,
}
@torch.no_grad()
@@ -1436,7 +1466,11 @@ class Vits(BaseTTS):
language_ids = torch.LongTensor(language_ids)
# get emotion embedding
if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_external_emotions_embeddings:
if (
self.emotion_manager is not None
and self.emotion_manager.embeddings
and self.args.use_external_emotions_embeddings
):
emotion_mapping = self.emotion_manager.embeddings
emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
emotion_embeddings = torch.FloatTensor(emotion_embeddings)
@@ -1627,13 +1661,9 @@ class Vits(BaseTTS):
emotion_manager = EmotionManager.init_from_config(config)
if config.model_args.encoder_model_path and speaker_manager is not None:
speaker_manager.init_encoder(
config.model_args.encoder_model_path, config.model_args.encoder_config_path
)
speaker_manager.init_encoder(config.model_args.encoder_model_path, config.model_args.encoder_config_path)
elif config.model_args.encoder_model_path and emotion_manager is not None:
emotion_manager.init_encoder(
config.model_args.encoder_model_path, config.model_args.encoder_config_path
)
emotion_manager.init_encoder(config.model_args.encoder_model_path, config.model_args.encoder_config_path)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)
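The reformatted aux_input dictionaries above also spell out the keys that Vits.forward and Vits.inference now expect. As a minimal, hypothetical sketch (not part of the commit), a caller could assemble that dictionary like this, where model stands for an already initialized Vits instance and the tensors are placeholders:

import torch

# Keys taken from the inference() signature in the hunk above; all of them
# may stay None for a single-speaker, non-emotional model.
token_ids = torch.randint(0, 100, (1, 50))            # placeholder token-id sequence [B, T]
aux_input = {
    "x_lengths": torch.tensor([token_ids.shape[1]]),  # per-sample lengths [B]
    "d_vectors": None,                                 # or a [B, d] speaker d-vector
    "speaker_ids": None,
    "language_ids": None,
    "emotion_embeddings": None,                        # or an external emotion embedding
    "emotion_ids": None,
}
# outputs = model.inference(token_ids, aux_input=aux_input)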

View File

@@ -8,6 +8,7 @@ from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default
from TTS.tts.utils.managers import EmbeddingManager
class EmotionManager(EmbeddingManager):
"""Manage the emotions for emotional TTS. Load a datafile and parse the information
in a way that can be queried by emotion or clip.
@@ -59,8 +60,8 @@ class EmotionManager(EmbeddingManager):
id_file_path=emotion_id_file_path,
encoder_model_path=encoder_model_path,
encoder_config_path=encoder_config_path,
use_cuda=use_cuda
)
use_cuda=use_cuda,
)
@property
def num_emotions(self):
@@ -98,13 +99,17 @@ class EmotionManager(EmbeddingManager):
)
elif get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
emotion_manager = EmotionManager(
embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
embeddings_file_path=get_from_config_or_model_args_with_default(
config, "external_emotions_embs_file", None
)
)
if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False):
if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
emotion_manager = EmotionManager(
embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
embeddings_file_path=get_from_config_or_model_args_with_default(
config, "external_emotions_embs_file", None
)
)
return emotion_manager
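For orientation, the two configuration branches above correspond to the two ways an EmotionManager is constructed. A hedged sketch follows; the file paths and the emotion name are placeholders, and emotion_id_file_path is an assumed keyword name mirrored from the constructor context shown earlier in this commit.

from TTS.tts.utils.emotions import EmotionManager

# ID-based setup (use_emotion_embedding): assumed keyword name, placeholder path.
manager_from_ids = EmotionManager(emotion_id_file_path="emotions.json")

# External-embeddings setup (use_external_emotions_embeddings): kwarg shown in the hunk above.
manager_from_embs = EmotionManager(embeddings_file_path="emotions_embs.json")
print(manager_from_embs.num_emotions)

# Mean embedding lookup, as called from vits.py in this commit; "Happy" is a placeholder name.
emotion_embedding = manager_from_embs.get_mean_embedding("Happy", num_samples=None, randomize=False)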
@@ -159,7 +164,9 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
if c.use_external_emotions_embeddings:
# restore emotion manager with the embedding file
if not os.path.exists(emotions_ids_file):
print("WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_embs_file")
print(
"WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_embs_file"
)
if not os.path.exists(c.external_emotions_embs_file):
raise RuntimeError(
"You must copy the file emotions.json to restore_path, or set a valid file in CONFIG.external_emotions_embs_file"
@@ -177,7 +184,7 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
elif c.use_emotion_embedding:
if "emotions_ids_file" in c and c.emotions_ids_file:
emotion_manager.load_ids_from_file(c.emotions_ids_file)
else: # get ids from the external embeddings file
else:  # get ids from the external embeddings file
emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)
if emotion_manager.num_emotions > 0:

View File

@@ -1,7 +1,6 @@
import os
from typing import Any, Dict, List
import fsspec
import numpy as np
import torch

View File

@@ -65,9 +65,8 @@ class SpeakerManager(EmbeddingManager):
id_file_path=speaker_id_file_path,
encoder_model_path=encoder_model_path,
encoder_config_path=encoder_config_path,
use_cuda=use_cuda
)
use_cuda=use_cuda,
)
if data_items:
self.set_ids_from_data(data_items, parse_key="speaker_name")

View File

@@ -124,7 +124,7 @@ def synthesis(
d_vector=None,
language_id=None,
emotion_id=None,
emotion_embedding=None
emotion_embedding=None,
):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
the vocoder model.
@@ -193,7 +193,16 @@ def synthesis(
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
text_inputs = text_inputs.unsqueeze(0)
# synthesize voice
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id, emotion_id=emotion_id, emotion_embedding=emotion_embedding)
outputs = run_model_torch(
model,
text_inputs,
speaker_id,
style_mel,
d_vector=d_vector,
language_id=language_id,
emotion_id=emotion_id,
emotion_embedding=emotion_embedding,
)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
alignments = outputs["alignments"]

View File

@@ -121,26 +121,42 @@ class Synthesizer(object):
if use_cuda:
self.tts_model.cuda()
if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager") and self.tts_model.speaker_manager is not None:
if (
self.encoder_checkpoint
and hasattr(self.tts_model, "speaker_manager")
and self.tts_model.speaker_manager is not None
):
self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config)
if self.tts_emotions_file and hasattr(self.tts_model, "emotion_manager") and self.tts_model.emotion_manager is not None:
if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (getattr(self.tts_config, "model_args", None) and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)):
if (
self.tts_emotions_file
and hasattr(self.tts_model, "emotion_manager")
and self.tts_model.emotion_manager is not None
):
if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
getattr(self.tts_config, "model_args", None)
and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)
):
self.tts_model.emotion_manager.load_embeddings_from_file(self.tts_emotions_file)
else:
self.tts_model.emotion_manager.load_ids_from_file(self.tts_emotions_file)
if self.tts_speakers_file and hasattr(self.tts_model, "speaker_manager") and self.tts_model.speaker_manager is not None:
if getattr(self.tts_config, "use_d_vector_file", False) or (getattr(self.tts_config, "model_args", None) and getattr(self.tts_config.model_args, "use_d_vector_file", False)):
if (
self.tts_speakers_file
and hasattr(self.tts_model, "speaker_manager")
and self.tts_model.speaker_manager is not None
):
if getattr(self.tts_config, "use_d_vector_file", False) or (
getattr(self.tts_config, "model_args", None)
and getattr(self.tts_config.model_args, "use_d_vector_file", False)
):
self.tts_model.speaker_manager.load_embeddings_from_file(self.tts_speakers_file)
else:
self.tts_model.speaker_manager.load_ids_from_file(self.tts_speakers_file)
def _set_speaker_encoder_paths_from_tts_config(self):
"""Set the encoder paths from the tts model config for models with speaker encoders."""
if hasattr(self.tts_config, "model_args") and hasattr(
self.tts_config.model_args, "encoder_config_path"
):
if hasattr(self.tts_config, "model_args") and hasattr(self.tts_config.model_args, "encoder_config_path"):
self.encoder_checkpoint = self.tts_config.model_args.encoder_model_path
self.encoder_config = self.tts_config.model_args.encoder_config_path
@@ -273,11 +289,18 @@ class Synthesizer(object):
# handle emotion
emotion_embedding, emotion_id = None, None
if self.tts_emotions_file or (getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)):
if self.tts_emotions_file or (
getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
):
if emotion_name and isinstance(emotion_name, str):
if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (getattr(self.tts_config, "model_args", None) and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)):
if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
getattr(self.tts_config, "model_args", None)
and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)
):
# get the average speaker embedding from the saved embeddings.
emotion_embedding = self.tts_model.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
emotion_embedding = self.tts_model.emotion_manager.get_mean_embedding(
emotion_name, num_samples=None, randomize=False
)
emotion_embedding = np.array(emotion_embedding)[None, :] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name