Change TrainingArgs -> TrainerArgs

This commit is contained in:
Eren Gölge 2022-02-03 15:42:12 +01:00
parent aa81454721
commit 27db089d6c
24 changed files with 50 additions and 51 deletions

View File

@ -1,7 +1,7 @@
import os
from TTS.config import load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
@ -9,7 +9,7 @@ from TTS.tts.models import setup_model
def main():
"""Run `tts` model training directly by a `config.json` file."""
# init trainer args
train_args = TrainingArgs()
train_args = TrainerArgs()
parser = train_args.init_argparse(arg_prefix="")
# override trainer args from comman-line args

View File

@ -1,7 +1,7 @@
import os
from TTS.config import load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model
@ -10,7 +10,7 @@ from TTS.vocoder.models import setup_model
def main():
"""Run `tts` model training directly by a `config.json` file."""
# init trainer args
train_args = TrainingArgs()
train_args = TrainerArgs()
parser = train_args.init_argparse(arg_prefix="")
# override trainer args from comman-line args

View File

@ -1,6 +1,6 @@
import os
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.align_tts_config import AlignTTSConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
@ -57,7 +57,7 @@ model = AlignTTS(config, ap, tokenizer)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.fast_pitch_config import FastPitchConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
@ -90,6 +90,6 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=None)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@ -1,7 +1,7 @@
import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.fast_speech_config import FastSpeechConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
@ -89,6 +89,6 @@ model = ForwardTTS(config, ap, tokenizer)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@ -2,7 +2,7 @@ import os
# Trainer: Where the ✨️ happens.
# TrainingArgs: Defines the set of arguments of the Trainer.
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
# GlowTTSConfig: all model related values for training, validating and testing.
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
@ -72,7 +72,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,6 +1,6 @@
import os
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import HifiganConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
@ -40,7 +40,7 @@ model = GAN(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
TrainerArgs(),
config,
output_path,
model=model,

View File

@ -1,6 +1,6 @@
import os
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import MultibandMelganConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
@ -40,7 +40,7 @@ model = GAN(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
TrainerArgs(),
config,
output_path,
model=model,

View File

@ -1,7 +1,7 @@
import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
@ -75,7 +75,7 @@ model = ForwardTTS(config, ap, tokenizer)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples
@ -88,7 +88,7 @@ model = Tacotron2(config, ap, tokenizer)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples
@ -83,7 +83,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager=None)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
TrainerArgs(),
config,
output_path,
model=model,

View File

@ -1,6 +1,6 @@
import os
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import UnivnetConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
@ -39,7 +39,7 @@ model = GAN(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
TrainerArgs(),
config,
output_path,
model=model,

View File

@ -1,7 +1,7 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
@ -33,7 +33,7 @@ audio_config = BaseAudioConfig(
config = VitsConfig(
audio=audio_config,
run_name="vits_ljspeech",
batch_size=16,
batch_size=32,
eval_batch_size=16,
batch_group_size=5,
num_loader_workers=0,
@ -48,8 +48,7 @@ config = VitsConfig(
compute_input_seq_cache=True,
print_step=25,
print_eval=True,
mixed_precision=False,
max_seq_len=500000,
mixed_precision=True,
output_path=output_path,
datasets=[dataset_config],
)
@ -76,7 +75,7 @@ model = Vits(config, ap, tokenizer, speaker_manager=None)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
TrainerArgs(),
config,
output_path,
model=model,

View File

@ -1,6 +1,6 @@
import os
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import WavegradConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
@ -37,7 +37,7 @@ model = Wavegrad(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
TrainerArgs(),
config,
output_path,
model=model,

View File

@ -1,6 +1,6 @@
import os
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import WavernnConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
@ -39,7 +39,7 @@ model = Wavernn(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
TrainerArgs(),
config,
output_path,
model=model,

View File

@ -2,7 +2,7 @@ import os
from glob import glob
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
@ -119,7 +119,7 @@ model = Vits(config, speaker_manager, language_manager)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
TrainerArgs(),
config,
output_path,
model=model,

View File

@ -1,7 +1,7 @@
import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.fast_pitch_config import FastPitchConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
@ -85,7 +85,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.fast_speech_config import FastSpeechConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
@ -83,7 +83,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
@ -83,7 +83,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
@ -83,7 +83,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.datasets import load_tts_samples
@ -85,7 +85,7 @@ model = Tacotron(config, ap, tokenizer, speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples
@ -91,7 +91,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples
@ -91,7 +91,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
@ -90,7 +90,7 @@ model = Vits(config, ap, tokenizer, speaker_manager)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
TrainerArgs(),
config,
output_path,
model=model,