diff --git a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py index 1b777803..59e90e0a 100644 --- a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py +++ b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py @@ -7,7 +7,7 @@ from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.utils.trainer_utils import get_last_checkpoint -config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json") +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") audio_config = BaseAudioConfig( @@ -45,6 +45,8 @@ config = FastPitchConfig( ], ) config.audio.do_trim_silence = True +config.use_speaker_embedding = True +config.model_args.use_speaker_embedding = True config.audio.trim_db = 60 config.save_json(config_path) diff --git a/tests/tts_tests/test_fast_pitch_train.py b/tests/tts_tests/test_fast_pitch_train.py index 9aae5bbd..bbfbb823 100644 --- a/tests/tts_tests/test_fast_pitch_train.py +++ b/tests/tts_tests/test_fast_pitch_train.py @@ -7,7 +7,7 @@ from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.utils.trainer_utils import get_last_checkpoint -config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json") +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") audio_config = BaseAudioConfig( @@ -42,8 +42,11 @@ config = FastPitchConfig( test_sentences=[ "Be a voice, not an echo.", ], + use_speaker_embedding=False, ) config.audio.do_trim_silence = True +config.use_speaker_embedding = False +config.model_args.use_speaker_embedding = False config.audio.trim_db = 60 config.save_json(config_path) @@ -58,6 +61,7 @@ command_train = ( "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " "--coqpit.test_delay_epochs 0" ) + run_cli(command_train) # Find latest folder