fix test vits

This commit is contained in:
WeberJulian 2021-10-29 17:09:10 +02:00 committed by Eren Gölge
parent 2a2b5767c2
commit 3b5592abcf
8 changed files with 17 additions and 32 deletions

View File

@ -261,7 +261,7 @@ class Trainer:
self.run_get_model(self.config, get_model)
if hasattr(self.model, "init_multilingual"):
self.model.init_multilingual(self.config, self.data_train + self.data_eval)
self.model.init_multilingual(self.config, self.train_samples + self.eval_samples)
config = self.config.model_args if hasattr(self.config, "model_args") else self.config
# save speakers json
if config.use_language_embedding and self.model.language_manager.num_languages > 1:

View File

@ -154,22 +154,6 @@ class VitsConfig(BaseTTSConfig):
d_vector_dim: int = None
def __post_init__(self):
# Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
if self.num_speakers > 0:
self.model_args.num_speakers = self.num_speakers
# speaker embedding settings
if self.use_speaker_embedding:
self.model_args.use_speaker_embedding = True
if self.speakers_file:
self.model_args.speakers_file = self.speakers_file
if self.speaker_embedding_channels:
self.model_args.speaker_embedding_channels = self.speaker_embedding_channels
# d-vector settings
if self.use_d_vector_file:
self.model_args.use_d_vector_file = True
if self.d_vector_dim is not None and self.d_vector_dim > 0:
self.model_args.d_vector_dim = self.d_vector_dim
if self.d_vector_file:
self.model_args.d_vector_file = self.d_vector_file
for key in self.model_args.keys():
if hasattr(self, key):
self[key] = self.model_args[key]

View File

@ -404,8 +404,7 @@ class TTSDataset(Dataset):
# get language ids from language names
if self.language_id_mapping is not None:
language_names = [batch[idx]["language_name"] for idx in ids_sorted_decreasing]
language_ids = [self.language_id_mapping[ln] for ln in language_names]
language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
else:
language_ids = None
# get pre-computed d-vectors

View File

@ -245,8 +245,13 @@ class BaseTTS(BaseModel):
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
if hasattr(config, "model_args"):
speaker_id_mapping = self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
config.use_d_vector_file = config.model_args.use_d_vector_file
else:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
else:
speaker_id_mapping = None
d_vector_mapping = None

View File

@ -376,8 +376,7 @@ class Vits(BaseTTS):
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
self.embedded_speaker_dim = 0
if hasattr(config, "model_args"):
config = config.model_args
config = config.model_args
self.num_speakers = config.num_speakers
@ -1033,7 +1032,6 @@ class Vits(BaseTTS):
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
for idx, s_info in enumerate(test_sentences):
try:
aux_inputs = self.get_aux_input_from_test_setences(s_info)
@ -1051,7 +1049,6 @@ class Vits(BaseTTS):
use_griffin_lim=True,
do_trim_silence=False,
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
except: # pylint: disable=bare-except

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import VitsConfig
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
@ -33,7 +33,6 @@ config.audio.do_trim_silence = True
config.audio.trim_db = 60
# active multispeaker d-vec mode
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.d_vector_dim = 256

View File

@ -3,7 +3,8 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import BaseDatasetConfig, VitsConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.config.shared_configs import BaseDatasetConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import VitsConfig
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")