Fix Style tests

Edresson Casanova 2022-03-30 16:51:39 -03:00
parent aebbdfc62b
commit 047cebd7b8
10 changed files with 128 additions and 60 deletions

View File

@@ -1,8 +1,8 @@
 import argparse
 import os
-import torch
 from argparse import RawTextHelpFormatter

+import torch
 from tqdm import tqdm

 from TTS.config import load_config
@@ -30,11 +30,11 @@ parser.add_argument(
     help="Path to dataset config file.",
 )
 parser.add_argument("output_path", type=str, help="path for output .json file.")
-parser.add_argument(
-    "--old_file", type=str, help="Previous .json file, only compute for new audios.", default=None
-)
+parser.add_argument("--old_file", type=str, help="Previous .json file, only compute for new audios.", default=None)
 parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
-parser.add_argument("--use_predicted_label", type=bool, help="If True and predicted label is available with will use it.", default=False)
+parser.add_argument(
+    "--use_predicted_label", type=bool, help="If True and predicted label is available with will use it.", default=False
+)
 parser.add_argument("--eval", type=bool, help="compute eval.", default=True)

 args = parser.parse_args()
@@ -71,7 +71,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
     embedd = encoder_manager.compute_embedding_from_clip(wav_file)
     if args.use_predicted_label:
-        map_classid_to_classname = getattr(encoder_manager.encoder_config, 'map_classid_to_classname', None)
+        map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
         if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
             embedding = torch.FloatTensor(embedd).unsqueeze(0)
             if encoder_manager.use_cuda:
@@ -80,9 +80,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
             class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
             class_name = map_classid_to_classname[str(class_id)]
         else:
-            raise RuntimeError(
-                " [!] use_predicted_label is enable and predicted_labels is not available !!"
-            )
+            raise RuntimeError(" [!] use_predicted_label is enable and predicted_labels is not available !!")

     # create class_mapping if target dataset is defined
     class_mapping[wav_file_name] = {}

View File

@@ -12,7 +12,7 @@ from TTS.tts.utils.speakers import SpeakerManager
 def compute_encoder_accuracy(dataset_items, encoder_manager):

     class_name_key = encoder_manager.encoder_config.class_name_key
-    map_classid_to_classname = getattr(encoder_manager.encoder_config, 'map_classid_to_classname', None)
+    map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
     class_acc_dict = {}

     # compute embeddings for all wav_files

View File

@@ -319,7 +319,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         args.speaker_wav,
         reference_wav=args.reference_wav,
         reference_speaker_name=args.reference_speaker_idx,
-        emotion_name=args.emotion_idx
+        emotion_name=args.emotion_idx,
     )

     # save the results

View File

@@ -424,7 +424,10 @@ class BaseTTS(BaseTrainerModel):
             output_path = os.path.join(trainer.output_path, "emotions.json")
             if hasattr(trainer.config, "model_args"):
-                if trainer.config.model_args.use_emotion_embedding and not trainer.config.model_args.external_emotions_embs_file:
+                if (
+                    trainer.config.model_args.use_emotion_embedding
+                    and not trainer.config.model_args.external_emotions_embs_file
+                ):
                     self.emotion_manager.save_ids_to_file(output_path)
                     trainer.config.model_args.emotions_ids_file = output_path
                 else:

View File

@@ -22,10 +22,10 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -773,8 +773,14 @@ class Vits(BaseTTS):

     def get_aux_input(self, aux_input: Dict):
         sid, g, lid, eid, eg = self._set_cond_input(aux_input)
-        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid,
-                "emotion_embeddings": eg, "emotion_ids": eid}
+        return {
+            "speaker_ids": sid,
+            "style_wav": None,
+            "d_vectors": g,
+            "language_ids": lid,
+            "emotion_embeddings": eg,
+            "emotion_ids": eid,
+        }

     def _freeze_layers(self):
         if self.args.freeze_encoder:
@@ -886,7 +892,13 @@ class Vits(BaseTTS):
         y: torch.tensor,
         y_lengths: torch.tensor,
         waveform: torch.tensor,
-        aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None},
+        aux_input={
+            "d_vectors": None,
+            "speaker_ids": None,
+            "language_ids": None,
+            "emotion_embeddings": None,
+            "emotion_ids": None,
+        },
     ) -> Dict:
         """Forward pass of the model.
@@ -974,7 +986,9 @@ class Vits(BaseTTS):
        )

        if self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss:
-            encoder = self.speaker_manager.encoder if self.args.use_speaker_encoder_as_loss else self.emotion_manager.encoder
+            encoder = (
+                self.speaker_manager.encoder if self.args.use_speaker_encoder_as_loss else self.emotion_manager.encoder
+            )

            # concate generated and GT waveforms
            wavs_batch = torch.cat((wav_seg, o), dim=0)
@@ -1018,7 +1032,16 @@ class Vits(BaseTTS):
         return torch.tensor(x.shape[1:2]).to(x.device)

     def inference(
-        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None}
+        self,
+        x,
+        aux_input={
+            "x_lengths": None,
+            "d_vectors": None,
+            "speaker_ids": None,
+            "language_ids": None,
+            "emotion_embeddings": None,
+            "emotion_ids": None,
+        },
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
@@ -1187,8 +1210,13 @@ class Vits(BaseTTS):
                 spec,
                 spec_lens,
                 waveform,
-                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids,
-                           "emotion_embeddings": emotion_embeddings, "emotion_ids": emotion_ids},
+                aux_input={
+                    "d_vectors": d_vectors,
+                    "speaker_ids": speaker_ids,
+                    "language_ids": language_ids,
+                    "emotion_embeddings": emotion_embeddings,
+                    "emotion_ids": emotion_ids,
+                },
             )

             # cache tensors for the generator pass
@@ -1246,7 +1274,8 @@ class Vits(BaseTTS):
                 feats_disc_fake=feats_disc_fake,
                 feats_disc_real=feats_disc_real,
                 loss_duration=self.model_outputs_cache["loss_duration"],
-                use_encoder_consistency_loss=self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss,
+                use_encoder_consistency_loss=self.args.use_speaker_encoder_as_loss
+                or self.args.use_emotion_encoder_as_loss,
                 gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                 syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
             )
@@ -1348,14 +1377,15 @@ class Vits(BaseTTS):
            if emotion_name is None:
                emotion_embedding = self.emotion_manager.get_random_embeddings()
            else:
-                emotion_embedding = self.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
+                emotion_embedding = self.emotion_manager.get_mean_embedding(
+                    emotion_name, num_samples=None, randomize=False
+                )

        elif config.use_emotion_embedding:
            if emotion_name is None:
                emotion_id = self.emotion_manager.get_random_id()
            else:
                emotion_id = self.emotion_manager.ids[emotion_name]

        return {
            "text": text,
            "speaker_id": speaker_id,
@@ -1364,7 +1394,7 @@ class Vits(BaseTTS):
            "language_id": language_id,
            "language_name": language_name,
            "emotion_embedding": emotion_embedding,
-            "emotion_ids": emotion_id
+            "emotion_ids": emotion_id,
        }

    @torch.no_grad()
@@ -1436,7 +1466,11 @@ class Vits(BaseTTS):
            language_ids = torch.LongTensor(language_ids)

        # get emotion embedding
-        if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_external_emotions_embeddings:
+        if (
+            self.emotion_manager is not None
+            and self.emotion_manager.embeddings
+            and self.args.use_external_emotions_embeddings
+        ):
            emotion_mapping = self.emotion_manager.embeddings
            emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
            emotion_embeddings = torch.FloatTensor(emotion_embeddings)
@@ -1627,13 +1661,9 @@ class Vits(BaseTTS):

        emotion_manager = EmotionManager.init_from_config(config)

        if config.model_args.encoder_model_path and speaker_manager is not None:
-            speaker_manager.init_encoder(
-                config.model_args.encoder_model_path, config.model_args.encoder_config_path
-            )
+            speaker_manager.init_encoder(config.model_args.encoder_model_path, config.model_args.encoder_config_path)
        elif config.model_args.encoder_model_path and emotion_manager is not None:
-            emotion_manager.init_encoder(
-                config.model_args.encoder_model_path, config.model_args.encoder_config_path
-            )
+            emotion_manager.init_encoder(config.model_args.encoder_model_path, config.model_args.encoder_config_path)

        return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)

View File

@@ -8,6 +8,7 @@ from coqpit import Coqpit
 from TTS.config import get_from_config_or_model_args_with_default
 from TTS.tts.utils.managers import EmbeddingManager

+
 class EmotionManager(EmbeddingManager):
     """Manage the emotions for emotional TTS. Load a datafile and parse the information
     in a way that can be queried by emotion or clip.
@@ -59,7 +60,7 @@ class EmotionManager(EmbeddingManager):
            id_file_path=emotion_id_file_path,
            encoder_model_path=encoder_model_path,
            encoder_config_path=encoder_config_path,
-            use_cuda=use_cuda
+            use_cuda=use_cuda,
        )

    @property
@@ -98,13 +99,17 @@ class EmotionManager(EmbeddingManager):
            )
        elif get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
            emotion_manager = EmotionManager(
-                embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
+                embeddings_file_path=get_from_config_or_model_args_with_default(
+                    config, "external_emotions_embs_file", None
+                )
            )

        if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False):
            if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
                emotion_manager = EmotionManager(
-                    embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
+                    embeddings_file_path=get_from_config_or_model_args_with_default(
+                        config, "external_emotions_embs_file", None
+                    )
                )

        return emotion_manager
@@ -159,7 +164,9 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
        if c.use_external_emotions_embeddings:
            # restore emotion manager with the embedding file
            if not os.path.exists(emotions_ids_file):
-                print("WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_embs_file")
+                print(
+                    "WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_embs_file"
+                )
                if not os.path.exists(c.external_emotions_embs_file):
                    raise RuntimeError(
                        "You must copy the file emotions.json to restore_path, or set a valid file in CONFIG.external_emotions_embs_file"

View File

@@ -1,7 +1,6 @@
 import os
 from typing import Any, Dict, List
-

 import fsspec
 import numpy as np
 import torch

View File

@@ -65,10 +65,9 @@ class SpeakerManager(EmbeddingManager):
            id_file_path=speaker_id_file_path,
            encoder_model_path=encoder_model_path,
            encoder_config_path=encoder_config_path,
-            use_cuda=use_cuda
+            use_cuda=use_cuda,
        )
-
        if data_items:
            self.set_ids_from_data(data_items, parse_key="speaker_name")

View File

@@ -124,7 +124,7 @@ def synthesis(
    d_vector=None,
    language_id=None,
    emotion_id=None,
-    emotion_embedding=None
+    emotion_embedding=None,
):
    """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
    the vocoder model.
@@ -193,7 +193,16 @@ def synthesis(
    text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
    text_inputs = text_inputs.unsqueeze(0)
    # synthesize voice
-    outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id, emotion_id=emotion_id, emotion_embedding=emotion_embedding)
+    outputs = run_model_torch(
+        model,
+        text_inputs,
+        speaker_id,
+        style_mel,
+        d_vector=d_vector,
+        language_id=language_id,
+        emotion_id=emotion_id,
+        emotion_embedding=emotion_embedding,
+    )
    model_outputs = outputs["model_outputs"]
    model_outputs = model_outputs[0].data.cpu().numpy()
    alignments = outputs["alignments"]

View File

@@ -121,26 +121,42 @@ class Synthesizer(object):
        if use_cuda:
            self.tts_model.cuda()

-        if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager") and self.tts_model.speaker_manager is not None:
+        if (
+            self.encoder_checkpoint
+            and hasattr(self.tts_model, "speaker_manager")
+            and self.tts_model.speaker_manager is not None
+        ):
            self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config)

-        if self.tts_emotions_file and hasattr(self.tts_model, "emotion_manager") and self.tts_model.emotion_manager is not None:
-            if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (getattr(self.tts_config, "model_args", None) and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)):
+        if (
+            self.tts_emotions_file
+            and hasattr(self.tts_model, "emotion_manager")
+            and self.tts_model.emotion_manager is not None
+        ):
+            if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
+                getattr(self.tts_config, "model_args", None)
+                and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)
+            ):
                self.tts_model.emotion_manager.load_embeddings_from_file(self.tts_emotions_file)
            else:
                self.tts_model.emotion_manager.load_ids_from_file(self.tts_emotions_file)

-        if self.tts_speakers_file and hasattr(self.tts_model, "speaker_manager") and self.tts_model.speaker_manager is not None:
-            if getattr(self.tts_config, "use_d_vector_file", False) or (getattr(self.tts_config, "model_args", None) and getattr(self.tts_config.model_args, "use_d_vector_file", False)):
+        if (
+            self.tts_speakers_file
+            and hasattr(self.tts_model, "speaker_manager")
+            and self.tts_model.speaker_manager is not None
+        ):
+            if getattr(self.tts_config, "use_d_vector_file", False) or (
+                getattr(self.tts_config, "model_args", None)
+                and getattr(self.tts_config.model_args, "use_d_vector_file", False)
+            ):
                self.tts_model.speaker_manager.load_embeddings_from_file(self.tts_speakers_file)
            else:
                self.tts_model.speaker_manager.load_ids_from_file(self.tts_speakers_file)

    def _set_speaker_encoder_paths_from_tts_config(self):
        """Set the encoder paths from the tts model config for models with speaker encoders."""
-        if hasattr(self.tts_config, "model_args") and hasattr(
-            self.tts_config.model_args, "encoder_config_path"
-        ):
+        if hasattr(self.tts_config, "model_args") and hasattr(self.tts_config.model_args, "encoder_config_path"):
            self.encoder_checkpoint = self.tts_config.model_args.encoder_model_path
            self.encoder_config = self.tts_config.model_args.encoder_config_path
@@ -273,11 +289,18 @@ class Synthesizer(object):

        # handle emotion
        emotion_embedding, emotion_id = None, None
-        if self.tts_emotions_file or (getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)):
+        if self.tts_emotions_file or (
+            getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
+        ):
            if emotion_name and isinstance(emotion_name, str):
-                if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (getattr(self.tts_config, "model_args", None) and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)):
+                if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
+                    getattr(self.tts_config, "model_args", None)
+                    and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)
+                ):
                    # get the average speaker embedding from the saved embeddings.
-                    emotion_embedding = self.tts_model.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
+                    emotion_embedding = self.tts_model.emotion_manager.get_mean_embedding(
+                        emotion_name, num_samples=None, randomize=False
+                    )
                    emotion_embedding = np.array(emotion_embedding)[None, :]  # [1 x embedding_dim]
                else:
                    # get speaker idx from the speaker name