diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index de683c81..5694fe4d 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -420,6 +420,76 @@ class TestVits(unittest.TestCase):
         # check parameter changes
         self._check_parameter_changes(model, model_ref)
 
+    def test_train_step_upsampling(self):
+        # setup the model
+        with torch.autograd.set_detect_anomaly(True):
+            model_args = VitsArgs(
+                num_chars=32,
+                spec_segment_size=10,
+                encoder_sample_rate=11025,
+                interpolate_z=False,
+                upsample_rates_decoder=[8, 8, 4, 2],
+            )
+            config = VitsConfig(model_args=model_args)
+            model = Vits(config).to(device)
+            model.train()
+            # model to train
+            optimizers = model.get_optimizer()
+            criterions = model.get_criterion()
+            criterions = [criterions[0].to(device), criterions[1].to(device)]
+            # reference model to compare model weights
+            model_ref = Vits(config).to(device)
+            # # pass the state to ref model
+            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+            count = 0
+            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+                assert (param - param_ref).sum() == 0, param
+                count = count + 1
+            for _ in range(5):
+                batch = self._create_batch(config, 2)
+                for idx in [0, 1]:
+                    outputs, loss_dict = model.train_step(batch, criterions, idx)
+                    self.assertFalse(not outputs)
+                    self.assertFalse(not loss_dict)
+                    loss_dict["loss"].backward()
+                    optimizers[idx].step()
+                    optimizers[idx].zero_grad()
+
+        # check parameter changes
+        self._check_parameter_changes(model, model_ref)
+
+    def test_train_step_upsampling_interpolation(self):
+        # setup the model
+        with torch.autograd.set_detect_anomaly(True):
+            model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True)
+            config = VitsConfig(model_args=model_args)
+            model = Vits(config).to(device)
+            model.train()
+            # model to train
+            optimizers = model.get_optimizer()
+            criterions = model.get_criterion()
+            criterions = [criterions[0].to(device), criterions[1].to(device)]
+            # reference model to compare model weights
+            model_ref = Vits(config).to(device)
+            # # pass the state to ref model
+            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+            count = 0
+            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+                assert (param - param_ref).sum() == 0, param
+                count = count + 1
+            for _ in range(5):
+                batch = self._create_batch(config, 2)
+                for idx in [0, 1]:
+                    outputs, loss_dict = model.train_step(batch, criterions, idx)
+                    self.assertFalse(not outputs)
+                    self.assertFalse(not loss_dict)
+                    loss_dict["loss"].backward()
+                    optimizers[idx].step()
+                    optimizers[idx].zero_grad()
+
+        # check parameter changes
+        self._check_parameter_changes(model, model_ref)
+
     def test_train_eval_log(self):
         batch_size = 2
         config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
diff --git a/tests/tts_tests/test_vits_speaker_emb_train_upsampling_interpolation_approach.py b/tests/tts_tests/test_vits_speaker_emb_train_upsampling_interpolation_approach.py
deleted file mode 100644
index c279d004..00000000
--- a/tests/tts_tests/test_vits_speaker_emb_train_upsampling_interpolation_approach.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import glob
-import json
-import os
-import shutil
-
-from trainer import get_last_checkpoint
-
-from tests import get_device_id, get_tests_output_path, run_cli
-from TTS.tts.configs.vits_config import VitsConfig
-
-config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
-output_path = os.path.join(get_tests_output_path(), "train_outputs")
-
-
-config = VitsConfig(
-    batch_size=2,
-    eval_batch_size=2,
-    num_loader_workers=0,
-    num_eval_loader_workers=0,
-    text_cleaner="english_cleaners",
-    use_phonemes=True,
-    phoneme_language="en-us",
-    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
-    run_eval=True,
-    test_delay_epochs=-1,
-    epochs=1,
-    print_step=1,
-    print_eval=True,
-    test_sentences=[
-        ["Be a voice, not an echo.", "ljspeech-1"],
-    ],
-)
-# set audio config
-config.audio.do_trim_silence = True
-config.audio.trim_db = 60
-
-# active multispeaker d-vec mode
-config.model_args.use_speaker_embedding = True
-config.model_args.use_d_vector_file = False
-config.model_args.d_vector_file = None
-config.model_args.d_vector_dim = 256
-
-
-# test upsample interpolation approach
-config.model_args.encoder_sample_rate = 11025
-config.model_args.interpolate_z = True
-config.model_args.upsample_rates_decoder = [8, 8, 2, 2]
-config.model_args.periods_multi_period_discriminator = [2, 3, 5, 7]
-
-
-config.save_json(config_path)
-
-# train the model for one epoch
-command_train = (
-    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
-    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech_test "
-    "--coqpit.datasets.0.meta_file_train metadata.csv "
-    "--coqpit.datasets.0.meta_file_val metadata.csv "
-    "--coqpit.datasets.0.path tests/data/ljspeech "
-    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
-    "--coqpit.test_delay_epochs 0"
-)
-run_cli(command_train)
-
-# Find latest folder
-continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
-
-# Inference using TTS API
-continue_config_path = os.path.join(continue_path, "config.json")
-continue_restore_path, _ = get_last_checkpoint(continue_path)
-out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
-speaker_id = "ljspeech-1"
-continue_speakers_path = os.path.join(continue_path, "speakers.json")
-
-# Check integrity of the config
-with open(continue_config_path, "r", encoding="utf-8") as f:
-    config_loaded = json.load(f)
-assert config_loaded["characters"] is not None
-assert config_loaded["output_path"] in continue_path
-assert config_loaded["test_delay_epochs"] == 0
-
-# Load the model and run inference
-inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
-run_cli(inference_command)
-
-# restore the model and continue training for one more epoch
-command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
-run_cli(command_train)
-shutil.rmtree(continue_path)
diff --git a/tests/tts_tests/test_vits_speaker_emb_train_upsampling_vocoder_approach.py b/tests/tts_tests/test_vits_speaker_emb_train_upsampling_vocoder_approach.py
deleted file mode 100644
index 35248b4c..00000000
--- a/tests/tts_tests/test_vits_speaker_emb_train_upsampling_vocoder_approach.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import glob
-import json
-import os
-import shutil
-
-from trainer import get_last_checkpoint
-
-from tests import get_device_id, get_tests_output_path, run_cli
-from TTS.tts.configs.vits_config import VitsConfig
-
-config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
-output_path = os.path.join(get_tests_output_path(), "train_outputs")
-
-
-config = VitsConfig(
-    batch_size=2,
-    eval_batch_size=2,
-    num_loader_workers=0,
-    num_eval_loader_workers=0,
-    text_cleaner="english_cleaners",
-    use_phonemes=True,
-    phoneme_language="en-us",
-    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
-    run_eval=True,
-    test_delay_epochs=-1,
-    epochs=1,
-    print_step=1,
-    print_eval=True,
-    test_sentences=[
-        ["Be a voice, not an echo.", "ljspeech-1"],
-    ],
-)
-# set audio config
-config.audio.do_trim_silence = True
-config.audio.trim_db = 60
-
-# active multispeaker d-vec mode
-config.model_args.use_speaker_embedding = True
-config.model_args.use_d_vector_file = False
-config.model_args.d_vector_file = None
-config.model_args.d_vector_dim = 256
-
-
-# test upsample
-config.model_args.encoder_sample_rate = 11025
-config.model_args.interpolate_z = False
-config.model_args.upsample_rates_decoder = [8, 8, 4, 2]
-config.model_args.periods_multi_period_discriminator = [2, 3, 5, 7]
-
-
-config.save_json(config_path)
-
-# train the model for one epoch
-command_train = (
-    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
-    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech_test "
-    "--coqpit.datasets.0.meta_file_train metadata.csv "
-    "--coqpit.datasets.0.meta_file_val metadata.csv "
-    "--coqpit.datasets.0.path tests/data/ljspeech "
-    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
-    "--coqpit.test_delay_epochs 0"
-)
-run_cli(command_train)
-
-# Find latest folder
-continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
-
-# Inference using TTS API
-continue_config_path = os.path.join(continue_path, "config.json")
-continue_restore_path, _ = get_last_checkpoint(continue_path)
-out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
-speaker_id = "ljspeech-1"
-continue_speakers_path = os.path.join(continue_path, "speakers.json")
-
-# Check integrity of the config
-with open(continue_config_path, "r", encoding="utf-8") as f:
-    config_loaded = json.load(f)
-assert config_loaded["characters"] is not None
-assert config_loaded["output_path"] in continue_path
-assert config_loaded["test_delay_epochs"] == 0
-
-# Load the model and run inference
-inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
-run_cli(inference_command)
-
-# restore the model and continue training for one more epoch
-command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
-run_cli(command_train)
-shutil.rmtree(continue_path)
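
For context on the decoder settings exercised above: with the standard LJSpeech VITS audio defaults (22050 Hz output and hop_length 256, so the default upsample_rates_decoder=[8, 8, 2, 2] multiplies to 256 — these defaults are assumptions, not stated in this diff), setting encoder_sample_rate=11025 halves the frame rate of the latent z. The missing 2x appears to be recovered either by interpolating z before the decoder (interpolate_z=True, keeping the default rates) or by making the decoder upsample twice as much (interpolate_z=False with [8, 8, 4, 2]). A minimal sanity-check sketch of that arithmetic:

# Sketch of the consistency the two upsampling approaches rely on.
# Assumed defaults (standard LJSpeech VITS audio config, not part of this diff):
# output sample rate 22050 Hz and hop_length 256.
from math import prod

output_sample_rate = 22050   # assumed config.audio.sample_rate
hop_length = 256             # assumed config.audio.hop_length
encoder_sample_rate = 11025  # from the tests and recipes above

interpolate_factor = output_sample_rate / encoder_sample_rate  # 2.0

# Vocoder approach (interpolate_z=False): the decoder supplies the extra 2x,
# so its upsample product must reach hop_length * interpolate_factor.
assert prod([8, 8, 4, 2]) == hop_length * interpolate_factor  # 512 == 512

# Interpolation approach (interpolate_z=True): z is stretched 2x in time first,
# so the decoder keeps the default hop-length product.
assert prod([8, 8, 2, 2]) == hop_length  # 256 == 256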
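
Assuming the usual pytest workflow for this test tree (the repo may also provide its own runner), the two new unit tests can be selected by name:

python -m pytest tests/tts_tests/test_vits.py -k "upsampling"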