Change TrainingArgs -> TrainerArgs

This commit is contained in:
Eren Gölge 2022-02-03 15:42:12 +01:00
parent aa81454721
commit 27db089d6c
24 changed files with 50 additions and 51 deletions

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config import load_config, register_config from TTS.config import load_config, register_config
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model from TTS.tts.models import setup_model
@ -9,7 +9,7 @@ from TTS.tts.models import setup_model
def main(): def main():
"""Run `tts` model training directly by a `config.json` file.""" """Run `tts` model training directly by a `config.json` file."""
# init trainer args # init trainer args
train_args = TrainingArgs() train_args = TrainerArgs()
parser = train_args.init_argparse(arg_prefix="") parser = train_args.init_argparse(arg_prefix="")
# override trainer args from comman-line args # override trainer args from comman-line args

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config import load_config, register_config from TTS.config import load_config, register_config
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model from TTS.vocoder.models import setup_model
@ -10,7 +10,7 @@ from TTS.vocoder.models import setup_model
def main(): def main():
"""Run `tts` model training directly by a `config.json` file.""" """Run `tts` model training directly by a `config.json` file."""
# init trainer args # init trainer args
train_args = TrainingArgs() train_args = TrainerArgs()
parser = train_args.init_argparse(arg_prefix="") parser = train_args.init_argparse(arg_prefix="")
# override trainer args from comman-line args # override trainer args from comman-line args

View File

@ -1,6 +1,6 @@
import os import os
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.align_tts_config import AlignTTSConfig from TTS.tts.configs.align_tts_config import AlignTTSConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@ -57,7 +57,7 @@ model = AlignTTS(config, ap, tokenizer)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.models.forward_tts import ForwardTTS
@ -90,6 +90,6 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=None)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
trainer.fit() trainer.fit()

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.fast_speech_config import FastSpeechConfig from TTS.tts.configs.fast_speech_config import FastSpeechConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.models.forward_tts import ForwardTTS
@ -89,6 +89,6 @@ model = ForwardTTS(config, ap, tokenizer)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
trainer.fit() trainer.fit()

View File

@ -2,7 +2,7 @@ import os
# Trainer: Where the ✨️ happens. # Trainer: Where the ✨️ happens.
# TrainingArgs: Defines the set of arguments of the Trainer. # TrainingArgs: Defines the set of arguments of the Trainer.
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
# GlowTTSConfig: all model related values for training, validating and testing. # GlowTTSConfig: all model related values for training, validating and testing.
from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.configs.glow_tts_config import GlowTTSConfig
@ -72,7 +72,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,6 +1,6 @@
import os import os
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import HifiganConfig from TTS.vocoder.configs import HifiganConfig
from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.preprocess import load_wav_data
@ -40,7 +40,7 @@ model = GAN(config)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), TrainerArgs(),
config, config,
output_path, output_path,
model=model, model=model,

View File

@ -1,6 +1,6 @@
import os import os
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import MultibandMelganConfig from TTS.vocoder.configs import MultibandMelganConfig
from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.preprocess import load_wav_data
@ -40,7 +40,7 @@ model = GAN(config)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), TrainerArgs(),
config, config,
output_path, output_path,
model=model, model=model,

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.models.forward_tts import ForwardTTS
@ -75,7 +75,7 @@ model = ForwardTTS(config, ap, tokenizer)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@ -88,7 +88,7 @@ model = Tacotron2(config, ap, tokenizer)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@ -83,7 +83,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager=None)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), TrainerArgs(),
config, config,
output_path, output_path,
model=model, model=model,

View File

@ -1,6 +1,6 @@
import os import os
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import UnivnetConfig from TTS.vocoder.configs import UnivnetConfig
from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.preprocess import load_wav_data
@ -39,7 +39,7 @@ model = GAN(config)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), TrainerArgs(),
config, config,
output_path, output_path,
model=model, model=model,

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@ -33,7 +33,7 @@ audio_config = BaseAudioConfig(
config = VitsConfig( config = VitsConfig(
audio=audio_config, audio=audio_config,
run_name="vits_ljspeech", run_name="vits_ljspeech",
batch_size=16, batch_size=32,
eval_batch_size=16, eval_batch_size=16,
batch_group_size=5, batch_group_size=5,
num_loader_workers=0, num_loader_workers=0,
@ -48,8 +48,7 @@ config = VitsConfig(
compute_input_seq_cache=True, compute_input_seq_cache=True,
print_step=25, print_step=25,
print_eval=True, print_eval=True,
mixed_precision=False, mixed_precision=True,
max_seq_len=500000,
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],
) )
@ -76,7 +75,7 @@ model = Vits(config, ap, tokenizer, speaker_manager=None)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), TrainerArgs(),
config, config,
output_path, output_path,
model=model, model=model,

View File

@ -1,6 +1,6 @@
import os import os
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import WavegradConfig from TTS.vocoder.configs import WavegradConfig
from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.preprocess import load_wav_data
@ -37,7 +37,7 @@ model = Wavegrad(config)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), TrainerArgs(),
config, config,
output_path, output_path,
model=model, model=model,

View File

@ -1,6 +1,6 @@
import os import os
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import WavernnConfig from TTS.vocoder.configs import WavernnConfig
from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.preprocess import load_wav_data
@ -39,7 +39,7 @@ model = Wavernn(config)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), TrainerArgs(),
config, config,
output_path, output_path,
model=model, model=model,

View File

@ -2,7 +2,7 @@ import os
from glob import glob from glob import glob
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@ -119,7 +119,7 @@ model = Vits(config, speaker_manager, language_manager)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), TrainerArgs(),
config, config,
output_path, output_path,
model=model, model=model,

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.models.forward_tts import ForwardTTS
@ -85,7 +85,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.fast_speech_config import FastSpeechConfig from TTS.tts.configs.fast_speech_config import FastSpeechConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.models.forward_tts import ForwardTTS
@ -83,7 +83,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@ -83,7 +83,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.models.forward_tts import ForwardTTS
@ -83,7 +83,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron_config import TacotronConfig from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@ -85,7 +85,7 @@ model = Tacotron(config, ap, tokenizer, speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@ -91,7 +91,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@ -91,7 +91,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀

View File

@ -1,7 +1,7 @@
import os import os
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@ -90,7 +90,7 @@ model = Vits(config, ap, tokenizer, speaker_manager)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(
TrainingArgs(), TrainerArgs(),
config, config,
output_path, output_path,
model=model, model=model,