diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 8dcde6bd..8061feec 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -632,7 +632,10 @@ class Vits(BaseTTS): ) if self.args.init_discriminator: - self.disc = VitsDiscriminator(periods=self.args.periods_multi_period_discriminator, use_spectral_norm=self.args.use_spectral_norm_disriminator) + self.disc = VitsDiscriminator( + periods=self.args.periods_multi_period_discriminator, + use_spectral_norm=self.args.use_spectral_norm_disriminator, + ) if self.args.TTS_part_sample_rate: self.interpolate_factor = self.config.audio["sample_rate"] / self.args.TTS_part_sample_rate diff --git a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py index 7b9b6335..683bb0a7 100644 --- a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py @@ -71,14 +71,6 @@ config.use_sdp = False # active language sampler config.use_language_weighted_sampler = True -# test upsample -config.model_args.TTS_part_sample_rate = 11025 -config.model_args.interpolate_z = False -config.model_args.detach_z_vocoder = True - -config.model_args.upsample_rates_decoder = [8, 8, 4, 2] -config.model_args.periods_multi_period_discriminator = [2, 3, 5, 7, 11, 13, 17, 19, 23] - config.save_json(config_path) # train the model for one epoch diff --git a/tests/tts_tests/test_vits_speaker_emb_train_upsampling_interpolation_approach.py b/tests/tts_tests/test_vits_speaker_emb_train_upsampling_interpolation_approach.py new file mode 100644 index 00000000..9d9e372c --- /dev/null +++ b/tests/tts_tests/test_vits_speaker_emb_train_upsampling_interpolation_approach.py @@ -0,0 +1,90 @@ +import glob +import json +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.vits_config import VitsConfig + +config_path = 
os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = VitsConfig( + batch_size=2, + eval_batch_size=2, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + ["Be a voice, not an echo.", "ljspeech-1"], + ], +) +# set audio config +config.audio.do_trim_silence = True +config.audio.trim_db = 60 + +# activate multi-speaker mode using speaker embeddings (d-vector file disabled) +config.model_args.use_speaker_embedding = True +config.model_args.use_d_vector_file = False +config.model_args.d_vector_file = None +config.model_args.d_vector_dim = 256 + + +# test upsample interpolation approach +config.model_args.TTS_part_sample_rate = 11025 +config.model_args.interpolate_z = True +config.model_args.upsample_rates_decoder = [8, 8, 2, 2] +config.model_args.periods_multi_period_discriminator = [2, 3, 5, 7] + + +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Paths for inference with the tts CLI +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = 
os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech-1" +continue_speakers_path = os.path.join(continue_path, "speakers.json") + +# Check integrity of the config +with open(continue_config_path, "r", encoding="utf-8") as f: + config_loaded = json.load(f) +assert config_loaded["characters"] is not None +assert config_loaded["output_path"] in continue_path +assert config_loaded["test_delay_epochs"] == 0 + +# Load the model and run inference +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) diff --git a/tests/tts_tests/test_vits_speaker_emb_train_upsampling_vocoder_approach.py b/tests/tts_tests/test_vits_speaker_emb_train_upsampling_vocoder_approach.py new file mode 100644 index 00000000..758aa4a1 --- /dev/null +++ b/tests/tts_tests/test_vits_speaker_emb_train_upsampling_vocoder_approach.py @@ -0,0 +1,90 @@ +import glob +import json +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.vits_config import VitsConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = VitsConfig( + batch_size=2, + eval_batch_size=2, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + 
epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + ["Be a voice, not an echo.", "ljspeech-1"], + ], +) +# set audio config +config.audio.do_trim_silence = True +config.audio.trim_db = 60 + +# activate multi-speaker mode using speaker embeddings (d-vector file disabled) +config.model_args.use_speaker_embedding = True +config.model_args.use_d_vector_file = False +config.model_args.d_vector_file = None +config.model_args.d_vector_dim = 256 + + +# test upsample vocoder approach +config.model_args.TTS_part_sample_rate = 11025 +config.model_args.interpolate_z = False +config.model_args.upsample_rates_decoder = [8, 8, 4, 2] +config.model_args.periods_multi_period_discriminator = [2, 3, 5, 7] + + +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Paths for inference with the tts CLI +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech-1" +continue_speakers_path = os.path.join(continue_path, "speakers.json") + +# Check integrity of the config +with open(continue_config_path, "r", encoding="utf-8") as f: + config_loaded = json.load(f) +assert config_loaded["characters"] is not None +assert config_loaded["output_path"] in continue_path +assert config_loaded["test_delay_epochs"] == 0 + +# Load the model and run inference +inference_command = 
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path)