From 27db089d6c27fb5a58513abd73c5451fac981a21 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Thu, 3 Feb 2022 15:42:12 +0100
Subject: [PATCH] Change TrainingArgs -> TrainerArgs

---
 TTS/bin/train_tts.py | 4 ++--
 TTS/bin/train_vocoder.py | 4 ++--
 recipes/ljspeech/align_tts/train_aligntts.py | 4 ++--
 recipes/ljspeech/fast_pitch/train_fast_pitch.py | 4 ++--
 recipes/ljspeech/fast_speech/train_fast_speech.py | 4 ++--
 recipes/ljspeech/glow_tts/train_glowtts.py | 4 ++--
 recipes/ljspeech/hifigan/train_hifigan.py | 4 ++--
 .../ljspeech/multiband_melgan/train_multiband_melgan.py | 4 ++--
 recipes/ljspeech/speedy_speech/train_speedy_speech.py | 4 ++--
 recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py | 4 ++--
 recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py | 4 ++--
 recipes/ljspeech/univnet/train.py | 4 ++--
 recipes/ljspeech/vits_tts/train_vits.py | 9 ++++-----
 recipes/ljspeech/wavegrad/train_wavegrad.py | 4 ++--
 recipes/ljspeech/wavernn/train_wavernn.py | 4 ++--
 recipes/multilingual/vits_tts/train_vits_tts.py | 4 ++--
 recipes/vctk/fast_pitch/train_fast_pitch.py | 4 ++--
 recipes/vctk/fast_speech/train_fast_speech.py | 4 ++--
 recipes/vctk/glow_tts/train_glow_tts.py | 4 ++--
 recipes/vctk/speedy_speech/train_speedy_speech.py | 4 ++--
 recipes/vctk/tacotron-DDC/train_tacotron-DDC.py | 4 ++--
 recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py | 4 ++--
 recipes/vctk/tacotron2/train_tacotron2.py | 4 ++--
 recipes/vctk/vits/train_vits.py | 4 ++--
 24 files changed, 50 insertions(+), 51 deletions(-)

diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index 824f0128..79b78767 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config import load_config, register_config
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models import setup_model

@@ -9,7 +9,7 @@ from TTS.tts.models import setup_model
 def main():
     """Run `tts` model training directly by a `config.json` file."""
     # init trainer args
-    train_args = TrainingArgs()
+    train_args = TrainerArgs()
     parser = train_args.init_argparse(arg_prefix="")

     # override trainer args from comman-line args
diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py
index cd665f29..081fdd56 100644
--- a/TTS/bin/train_vocoder.py
+++ b/TTS/bin/train_vocoder.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config import load_config, register_config
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model
@@ -10,7 +10,7 @@ from TTS.vocoder.models import setup_model
 def main():
     """Run `tts` model training directly by a `config.json` file."""
     # init trainer args
-    train_args = TrainingArgs()
+    train_args = TrainerArgs()
     parser = train_args.init_argparse(arg_prefix="")

     # override trainer args from comman-line args
diff --git a/recipes/ljspeech/align_tts/train_aligntts.py b/recipes/ljspeech/align_tts/train_aligntts.py
index d0187aa8..a4b868aa 100644
--- a/recipes/ljspeech/align_tts/train_aligntts.py
+++ b/recipes/ljspeech/align_tts/train_aligntts.py
@@ -1,6 +1,6 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.align_tts_config import AlignTTSConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
@@ -57,7 +57,7 @@ model = AlignTTS(config, ap, tokenizer)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py
index 3a772251..fcb62282 100644
--- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py
+++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.fast_pitch_config import FastPitchConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.forward_tts import ForwardTTS
@@ -90,6 +90,6 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=None)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()
diff --git a/recipes/ljspeech/fast_speech/train_fast_speech.py b/recipes/ljspeech/fast_speech/train_fast_speech.py
index f9f1bc06..183c8ebb 100644
--- a/recipes/ljspeech/fast_speech/train_fast_speech.py
+++ b/recipes/ljspeech/fast_speech/train_fast_speech.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config import BaseAudioConfig, BaseDatasetConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.fast_speech_config import FastSpeechConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.forward_tts import ForwardTTS
@@ -89,6 +89,6 @@ model = ForwardTTS(config, ap, tokenizer)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()
diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py
index dd450a57..c47cd00a 100644
--- a/recipes/ljspeech/glow_tts/train_glowtts.py
+++ b/recipes/ljspeech/glow_tts/train_glowtts.py
@@ -2,7 +2,7 @@ import os

 # Trainer: Where the ✨️ happens.
 # TrainingArgs: Defines the set of arguments of the Trainer.
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs

 # GlowTTSConfig: all model related values for training, validating and testing.
 from TTS.tts.configs.glow_tts_config import GlowTTSConfig
@@ -72,7 +72,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/ljspeech/hifigan/train_hifigan.py b/recipes/ljspeech/hifigan/train_hifigan.py
index 8d1c272a..964a6420 100644
--- a/recipes/ljspeech/hifigan/train_hifigan.py
+++ b/recipes/ljspeech/hifigan/train_hifigan.py
@@ -1,6 +1,6 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import HifiganConfig
 from TTS.vocoder.datasets.preprocess import load_wav_data
@@ -40,7 +40,7 @@ model = GAN(config)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(),
+    TrainerArgs(),
     config,
     output_path,
     model=model,
diff --git a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py
index 90c52997..6f528a83 100644
--- a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py
+++ b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py
@@ -1,6 +1,6 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import MultibandMelganConfig
 from TTS.vocoder.datasets.preprocess import load_wav_data
@@ -40,7 +40,7 @@ model = GAN(config)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(),
+    TrainerArgs(),
     config,
     output_path,
     model=model,
diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py
index 2f8896c5..6a9ddf16 100644
--- a/recipes/ljspeech/speedy_speech/train_speedy_speech.py
+++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config import BaseAudioConfig, BaseDatasetConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.forward_tts import ForwardTTS
@@ -75,7 +75,7 @@ model = ForwardTTS(config, ap, tokenizer)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py
index a7f037e6..c3a1c51c 100644
--- a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py
+++ b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.tacotron2_config import Tacotron2Config
 from TTS.tts.datasets import load_tts_samples
@@ -88,7 +88,7 @@ model = Tacotron2(config, ap, tokenizer)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py
index 285c416c..a7482b32 100644
--- a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py
+++ b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.tacotron2_config import Tacotron2Config
 from TTS.tts.datasets import load_tts_samples
@@ -83,7 +83,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager=None)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(),
+    TrainerArgs(),
     config,
     output_path,
     model=model,
diff --git a/recipes/ljspeech/univnet/train.py b/recipes/ljspeech/univnet/train.py
index 589fd027..35240c5b 100644
--- a/recipes/ljspeech/univnet/train.py
+++ b/recipes/ljspeech/univnet/train.py
@@ -1,6 +1,6 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import UnivnetConfig
 from TTS.vocoder.datasets.preprocess import load_wav_data
@@ -39,7 +39,7 @@ model = GAN(config)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(),
+    TrainerArgs(),
     config,
     output_path,
     model=model,
diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py
index 79c0db2e..24ff4d0f 100644
--- a/recipes/ljspeech/vits_tts/train_vits.py
+++ b/recipes/ljspeech/vits_tts/train_vits.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
@@ -33,7 +33,7 @@ audio_config = BaseAudioConfig(
 config = VitsConfig(
     audio=audio_config,
     run_name="vits_ljspeech",
-    batch_size=16,
+    batch_size=32,
     eval_batch_size=16,
     batch_group_size=5,
     num_loader_workers=0,
@@ -48,8 +48,7 @@ config = VitsConfig(
     compute_input_seq_cache=True,
     print_step=25,
     print_eval=True,
-    mixed_precision=False,
-    max_seq_len=500000,
+    mixed_precision=True,
     output_path=output_path,
     datasets=[dataset_config],
 )
@@ -76,7 +75,7 @@ model = Vits(config, ap, tokenizer, speaker_manager=None)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(),
+    TrainerArgs(),
     config,
     output_path,
     model=model,
diff --git a/recipes/ljspeech/wavegrad/train_wavegrad.py b/recipes/ljspeech/wavegrad/train_wavegrad.py
index 6786c052..095773d6 100644
--- a/recipes/ljspeech/wavegrad/train_wavegrad.py
+++ b/recipes/ljspeech/wavegrad/train_wavegrad.py
@@ -1,6 +1,6 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import WavegradConfig
 from TTS.vocoder.datasets.preprocess import load_wav_data
@@ -37,7 +37,7 @@ model = Wavegrad(config)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(),
+    TrainerArgs(),
     config,
     output_path,
     model=model,
diff --git a/recipes/ljspeech/wavernn/train_wavernn.py b/recipes/ljspeech/wavernn/train_wavernn.py
index f64f5752..172b489a 100644
--- a/recipes/ljspeech/wavernn/train_wavernn.py
+++ b/recipes/ljspeech/wavernn/train_wavernn.py
@@ -1,6 +1,6 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import WavernnConfig
 from TTS.vocoder.datasets.preprocess import load_wav_data
@@ -39,7 +39,7 @@ model = Wavernn(config)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(),
+    TrainerArgs(),
     config,
     output_path,
     model=model,
diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py
index be4747df..391f31cb 100644
--- a/recipes/multilingual/vits_tts/train_vits_tts.py
+++ b/recipes/multilingual/vits_tts/train_vits_tts.py
@@ -2,7 +2,7 @@ import os
 from glob import glob

 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
@@ -119,7 +119,7 @@ model = Vits(config, speaker_manager, language_manager)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(),
+    TrainerArgs(),
     config,
     output_path,
     model=model,
diff --git a/recipes/vctk/fast_pitch/train_fast_pitch.py b/recipes/vctk/fast_pitch/train_fast_pitch.py
index 4d9cc10d..aeb62055 100644
--- a/recipes/vctk/fast_pitch/train_fast_pitch.py
+++ b/recipes/vctk/fast_pitch/train_fast_pitch.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config import BaseAudioConfig, BaseDatasetConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.fast_pitch_config import FastPitchConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.forward_tts import ForwardTTS
@@ -85,7 +85,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/vctk/fast_speech/train_fast_speech.py b/recipes/vctk/fast_speech/train_fast_speech.py
index 1dcab982..578fbd1a 100644
--- a/recipes/vctk/fast_speech/train_fast_speech.py
+++ b/recipes/vctk/fast_speech/train_fast_speech.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config import BaseAudioConfig, BaseDatasetConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.fast_speech_config import FastSpeechConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.forward_tts import ForwardTTS
@@ -83,7 +83,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/vctk/glow_tts/train_glow_tts.py b/recipes/vctk/glow_tts/train_glow_tts.py
index e35e552d..0f198a86 100644
--- a/recipes/vctk/glow_tts/train_glow_tts.py
+++ b/recipes/vctk/glow_tts/train_glow_tts.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.glow_tts_config import GlowTTSConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
@@ -83,7 +83,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=speaker_manager)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/vctk/speedy_speech/train_speedy_speech.py b/recipes/vctk/speedy_speech/train_speedy_speech.py
index 85e347fc..fbb1af2d 100644
--- a/recipes/vctk/speedy_speech/train_speedy_speech.py
+++ b/recipes/vctk/speedy_speech/train_speedy_speech.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config import BaseAudioConfig, BaseDatasetConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.forward_tts import ForwardTTS
@@ -83,7 +83,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py
index 7960b34b..917c5588 100644
--- a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py
+++ b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.tacotron_config import TacotronConfig
 from TTS.tts.datasets import load_tts_samples
@@ -85,7 +85,7 @@ model = Tacotron(config, ap, tokenizer, speaker_manager)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py
index bc7951b5..759ddd57 100644
--- a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py
+++ b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.tacotron2_config import Tacotron2Config
 from TTS.tts.datasets import load_tts_samples
@@ -91,7 +91,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/vctk/tacotron2/train_tacotron2.py b/recipes/vctk/tacotron2/train_tacotron2.py
index 82dedade..0c62da48 100644
--- a/recipes/vctk/tacotron2/train_tacotron2.py
+++ b/recipes/vctk/tacotron2/train_tacotron2.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.tacotron2_config import Tacotron2Config
 from TTS.tts.datasets import load_tts_samples
@@ -91,7 +91,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager)
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
 # distributed training, etc.
 trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )

 # AND... 3,2,1... 🚀
diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py
index caf1caa1..53d7242c 100644
--- a/recipes/vctk/vits/train_vits.py
+++ b/recipes/vctk/vits/train_vits.py
@@ -1,7 +1,7 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
@@ -90,7 +90,7 @@ model = Vits(config, ap, tokenizer, speaker_manager)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainingArgs(),
+    TrainerArgs(),
     config,
     output_path,
     model=model,
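
Note (not part of the patch): after this change every recipe imports Trainer and TrainerArgs from the standalone trainer package instead of TTS.trainer. The sketch below only illustrates the before/after call-site pattern shown in the hunks above; config, model, train_samples and eval_samples are assumed to be set up exactly as in the recipes and are elided here.

# Before (old API):
# from TTS.trainer import Trainer, TrainingArgs
# trainer = Trainer(TrainingArgs(), config, output_path, model=model,
#                   train_samples=train_samples, eval_samples=eval_samples)

# After this patch:
from trainer import Trainer, TrainerArgs

trainer = Trainer(
    TrainerArgs(),  # renamed from TrainingArgs()
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)
trainer.fit()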