From e6f45b9eb712d3ac7f523552b453e8bda104880d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 7 May 2021 03:39:49 +0200 Subject: [PATCH] update train_vocoder_gan.py for coqpit --- TTS/bin/train_vocoder_gan.py | 13 +++++++------ .../test_vocoder_gan_datasets.py | 0 tests/{ => vocoder_tests}/test_vocoder_losses.py | 0 .../test_vocoder_melgan_discriminator.py | 0 .../test_vocoder_melgan_generator.py | 0 .../test_vocoder_parallel_wavegan_discriminator.py | 0 .../test_vocoder_parallel_wavegan_generator.py | 0 tests/{ => vocoder_tests}/test_vocoder_pqmf.py | 0 tests/{ => vocoder_tests}/test_vocoder_rwd.py | 0 .../test_vocoder_tf_melgan_generator.py | 0 tests/{ => vocoder_tests}/test_vocoder_tf_pqmf.py | 0 tests/{ => vocoder_tests}/test_vocoder_wavernn.py | 0 .../test_vocoder_wavernn_datasets.py | 0 .../test_wavegrad.py} | 0 tests/{ => vocoder_tests}/test_wavegrad_layers.py | 0 15 files changed, 7 insertions(+), 6 deletions(-) rename tests/{ => vocoder_tests}/test_vocoder_gan_datasets.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_losses.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_melgan_discriminator.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_melgan_generator.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_parallel_wavegan_discriminator.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_parallel_wavegan_generator.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_pqmf.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_rwd.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_tf_melgan_generator.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_tf_pqmf.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_wavernn.py (100%) rename tests/{ => vocoder_tests}/test_vocoder_wavernn_datasets.py (100%) rename tests/{test_wavegrad_train.py => vocoder_tests/test_wavegrad.py} (100%) rename tests/{ => vocoder_tests}/test_wavegrad_layers.py (100%) diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index f33df3e8..4159f12f 100755 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from TTS.utils.arguments import parse_arguments, process_args +from TTS.utils.arguments import init_training from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict @@ -163,7 +163,6 @@ def train( y_hat_sub=y_hat_sub, y_sub=y_G_sub, ) - loss_G = loss_G_dict["G_loss"] # optimizer generator @@ -469,7 +468,7 @@ def main(args): # pylint: disable=redefined-outer-name eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) # setup audio processor - ap = AudioProcessor(**c.audio) + ap = AudioProcessor(**c.audio.to_dict()) # DISTRUBUTED if num_gpus > 1: @@ -620,13 +619,15 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": - args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder") - + args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv) try: main(args) except KeyboardInterrupt: remove_experiment_folder(OUT_PATH) + try: + sys.exit(0) + except SystemExit: + os._exit(0) # pylint: disable=protected-access except Exception: # pylint: disable=broad-except remove_experiment_folder(OUT_PATH) traceback.print_exc() diff --git a/tests/test_vocoder_gan_datasets.py b/tests/vocoder_tests/test_vocoder_gan_datasets.py similarity index 100% rename from tests/test_vocoder_gan_datasets.py rename to tests/vocoder_tests/test_vocoder_gan_datasets.py diff --git a/tests/test_vocoder_losses.py b/tests/vocoder_tests/test_vocoder_losses.py similarity index 100% rename from tests/test_vocoder_losses.py rename to tests/vocoder_tests/test_vocoder_losses.py diff --git a/tests/test_vocoder_melgan_discriminator.py b/tests/vocoder_tests/test_vocoder_melgan_discriminator.py similarity index 100% rename from tests/test_vocoder_melgan_discriminator.py rename to tests/vocoder_tests/test_vocoder_melgan_discriminator.py diff --git a/tests/test_vocoder_melgan_generator.py b/tests/vocoder_tests/test_vocoder_melgan_generator.py similarity index 100% rename from tests/test_vocoder_melgan_generator.py rename to tests/vocoder_tests/test_vocoder_melgan_generator.py diff --git a/tests/test_vocoder_parallel_wavegan_discriminator.py b/tests/vocoder_tests/test_vocoder_parallel_wavegan_discriminator.py similarity index 100% rename from tests/test_vocoder_parallel_wavegan_discriminator.py rename to tests/vocoder_tests/test_vocoder_parallel_wavegan_discriminator.py diff --git a/tests/test_vocoder_parallel_wavegan_generator.py b/tests/vocoder_tests/test_vocoder_parallel_wavegan_generator.py similarity index 100% rename from tests/test_vocoder_parallel_wavegan_generator.py rename to tests/vocoder_tests/test_vocoder_parallel_wavegan_generator.py diff --git a/tests/test_vocoder_pqmf.py b/tests/vocoder_tests/test_vocoder_pqmf.py similarity index 100% rename from tests/test_vocoder_pqmf.py rename to tests/vocoder_tests/test_vocoder_pqmf.py diff --git a/tests/test_vocoder_rwd.py b/tests/vocoder_tests/test_vocoder_rwd.py similarity index 100% rename from tests/test_vocoder_rwd.py rename to tests/vocoder_tests/test_vocoder_rwd.py diff --git a/tests/test_vocoder_tf_melgan_generator.py b/tests/vocoder_tests/test_vocoder_tf_melgan_generator.py similarity index 100% rename from tests/test_vocoder_tf_melgan_generator.py rename to tests/vocoder_tests/test_vocoder_tf_melgan_generator.py diff --git a/tests/test_vocoder_tf_pqmf.py b/tests/vocoder_tests/test_vocoder_tf_pqmf.py similarity index 100% rename from tests/test_vocoder_tf_pqmf.py rename to tests/vocoder_tests/test_vocoder_tf_pqmf.py diff --git a/tests/test_vocoder_wavernn.py b/tests/vocoder_tests/test_vocoder_wavernn.py similarity index 100% rename from tests/test_vocoder_wavernn.py rename to tests/vocoder_tests/test_vocoder_wavernn.py diff --git a/tests/test_vocoder_wavernn_datasets.py b/tests/vocoder_tests/test_vocoder_wavernn_datasets.py similarity index 100% rename from tests/test_vocoder_wavernn_datasets.py rename to tests/vocoder_tests/test_vocoder_wavernn_datasets.py diff --git a/tests/test_wavegrad_train.py b/tests/vocoder_tests/test_wavegrad.py similarity index 100% rename from tests/test_wavegrad_train.py rename to tests/vocoder_tests/test_wavegrad.py diff --git a/tests/test_wavegrad_layers.py b/tests/vocoder_tests/test_wavegrad_layers.py similarity index 100% rename from tests/test_wavegrad_layers.py rename to tests/vocoder_tests/test_wavegrad_layers.py