mirror of https://github.com/coqui-ai/TTS.git
fix test vits
This commit is contained in:
parent
2a2b5767c2
commit
3b5592abcf
|
@ -261,7 +261,7 @@ class Trainer:
|
|||
self.run_get_model(self.config, get_model)
|
||||
|
||||
if hasattr(self.model, "init_multilingual"):
|
||||
self.model.init_multilingual(self.config, self.data_train + self.data_eval)
|
||||
self.model.init_multilingual(self.config, self.train_samples + self.eval_samples)
|
||||
config = self.config.model_args if hasattr(self.config, "model_args") else self.config
|
||||
# save speakers json
|
||||
if config.use_language_embedding and self.model.language_manager.num_languages > 1:
|
||||
|
|
|
@ -154,22 +154,6 @@ class VitsConfig(BaseTTSConfig):
|
|||
d_vector_dim: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
|
||||
if self.num_speakers > 0:
|
||||
self.model_args.num_speakers = self.num_speakers
|
||||
|
||||
# speaker embedding settings
|
||||
if self.use_speaker_embedding:
|
||||
self.model_args.use_speaker_embedding = True
|
||||
if self.speakers_file:
|
||||
self.model_args.speakers_file = self.speakers_file
|
||||
if self.speaker_embedding_channels:
|
||||
self.model_args.speaker_embedding_channels = self.speaker_embedding_channels
|
||||
|
||||
# d-vector settings
|
||||
if self.use_d_vector_file:
|
||||
self.model_args.use_d_vector_file = True
|
||||
if self.d_vector_dim is not None and self.d_vector_dim > 0:
|
||||
self.model_args.d_vector_dim = self.d_vector_dim
|
||||
if self.d_vector_file:
|
||||
self.model_args.d_vector_file = self.d_vector_file
|
||||
for key in self.model_args.keys():
|
||||
if hasattr(self, key):
|
||||
self[key] = self.model_args[key]
|
||||
|
|
|
@ -404,8 +404,7 @@ class TTSDataset(Dataset):
|
|||
|
||||
# get language ids from language names
|
||||
if self.language_id_mapping is not None:
|
||||
language_names = [batch[idx]["language_name"] for idx in ids_sorted_decreasing]
|
||||
language_ids = [self.language_id_mapping[ln] for ln in language_names]
|
||||
language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
|
||||
else:
|
||||
language_ids = None
|
||||
# get pre-computed d-vectors
|
||||
|
|
|
@ -245,8 +245,13 @@ class BaseTTS(BaseModel):
|
|||
|
||||
# setup multi-speaker attributes
|
||||
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
|
||||
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
|
||||
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
|
||||
if hasattr(config, "model_args"):
|
||||
speaker_id_mapping = self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
|
||||
d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
|
||||
config.use_d_vector_file = config.model_args.use_d_vector_file
|
||||
else:
|
||||
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
|
||||
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
|
||||
else:
|
||||
speaker_id_mapping = None
|
||||
d_vector_mapping = None
|
||||
|
|
|
@ -376,8 +376,7 @@ class Vits(BaseTTS):
|
|||
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||
"""
|
||||
self.embedded_speaker_dim = 0
|
||||
if hasattr(config, "model_args"):
|
||||
config = config.model_args
|
||||
config = config.model_args
|
||||
|
||||
self.num_speakers = config.num_speakers
|
||||
|
||||
|
@ -1033,7 +1032,6 @@ class Vits(BaseTTS):
|
|||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
|
||||
for idx, s_info in enumerate(test_sentences):
|
||||
try:
|
||||
aux_inputs = self.get_aux_input_from_test_setences(s_info)
|
||||
|
@ -1051,7 +1049,6 @@ class Vits(BaseTTS):
|
|||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
).values()
|
||||
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
|
||||
except: # pylint: disable=bare-except
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import VitsConfig
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
@ -33,7 +33,6 @@ config.audio.do_trim_silence = True
|
|||
config.audio.trim_db = 60
|
||||
|
||||
# active multispeaker d-vec mode
|
||||
config.model_args.use_speaker_embedding = True
|
||||
config.model_args.use_d_vector_file = True
|
||||
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
|
||||
config.model_args.d_vector_dim = 256
|
||||
|
|
|
@ -3,7 +3,8 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import BaseDatasetConfig, VitsConfig
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import VitsConfig
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
Loading…
Reference in New Issue