diff --git a/recipes/ljspeech/align_tts/train_aligntts.py b/recipes/ljspeech/align_tts/train_aligntts.py index 4a4f86c4..85c22673 100644 --- a/recipes/ljspeech/align_tts/train_aligntts.py +++ b/recipes/ljspeech/align_tts/train_aligntts.py @@ -1,12 +1,12 @@ import os -from TTS.tts.configs import AlignTTSConfig -from TTS.tts.configs import BaseDatasetConfig -from TTS.trainer import init_training, Trainer, TrainingArgs - +from TTS.trainer import Trainer, TrainingArgs, init_training +from TTS.tts.configs import AlignTTSConfig, BaseDatasetConfig output_path = os.path.dirname(os.path.abspath(__file__)) -dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")) +dataset_config = BaseDatasetConfig( + name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/") +) config = AlignTTSConfig( batch_size=32, eval_batch_size=16, @@ -23,7 +23,7 @@ config = AlignTTSConfig( print_eval=True, mixed_precision=False, output_path=output_path, - datasets=[dataset_config] + datasets=[dataset_config], ) args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) trainer = Trainer(args, config, output_path, c_logger, tb_logger) diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index 0a3c3838..f77997e8 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -1,12 +1,12 @@ import os -from TTS.tts.configs import GlowTTSConfig -from TTS.tts.configs import BaseDatasetConfig -from TTS.trainer import init_training, Trainer, TrainingArgs - +from TTS.trainer import Trainer, TrainingArgs, init_training +from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig output_path = os.path.dirname(os.path.abspath(__file__)) -dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")) +dataset_config = BaseDatasetConfig( + name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/") +) config = GlowTTSConfig( batch_size=32, eval_batch_size=16, @@ -23,7 +23,7 @@ config = GlowTTSConfig( print_eval=True, mixed_precision=False, output_path=output_path, - datasets=[dataset_config] + datasets=[dataset_config], ) args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) trainer = Trainer(args, config, output_path, c_logger, tb_logger) diff --git a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py index 6b766ab7..9d1b9e6f 100644 --- a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py +++ b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py @@ -1,8 +1,7 @@ import os +from TTS.trainer import Trainer, TrainingArgs, init_training from TTS.vocoder.configs import MultibandMelganConfig -from TTS.trainer import init_training, Trainer, TrainingArgs - output_path = os.path.dirname(os.path.abspath(__file__)) config = MultibandMelganConfig( diff --git a/recipes/ljspeech/univnet/train.py b/recipes/ljspeech/univnet/train.py index a442b451..d8f33ae3 100644 --- a/recipes/ljspeech/univnet/train.py +++ b/recipes/ljspeech/univnet/train.py @@ -1,6 +1,5 @@ import os -from TTS.config.shared_configs import BaseAudioConfig from TTS.trainer import Trainer, TrainingArgs, init_training from TTS.vocoder.configs import UnivnetConfig diff --git a/recipes/ljspeech/wavegrad/train_wavegrad.py b/recipes/ljspeech/wavegrad/train_wavegrad.py index 323b2bb7..4f82f50b 100644 --- a/recipes/ljspeech/wavegrad/train_wavegrad.py +++ b/recipes/ljspeech/wavegrad/train_wavegrad.py @@ -1,10 +1,8 @@ import os -from TTS.trainer import Trainer, init_training -from TTS.trainer import TrainingArgs +from TTS.trainer import Trainer, TrainingArgs, init_training from TTS.vocoder.configs import WavegradConfig - output_path = os.path.dirname(os.path.abspath(__file__)) config = WavegradConfig( batch_size=32, diff --git a/recipes/ljspeech/wavernn/train_wavernn.py b/recipes/ljspeech/wavernn/train_wavernn.py index 76ff722a..7222f78a 100644 --- a/recipes/ljspeech/wavernn/train_wavernn.py +++ b/recipes/ljspeech/wavernn/train_wavernn.py @@ -1,9 +1,8 @@ import os -from TTS.trainer import Trainer, init_training, TrainingArgs +from TTS.trainer import Trainer, TrainingArgs, init_training from TTS.vocoder.configs import WavernnConfig - output_path = os.path.dirname(os.path.abspath(__file__)) config = WavernnConfig( batch_size=64,