From be3a03126ad4b2630c6006019dc28e3bcea6c994 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:35:42 +0100 Subject: [PATCH] Update imports for trainer --- TTS/bin/distribute.py | 4 ++-- TTS/bin/train_encoder.py | 4 +++- TTS/bin/train_tts.py | 8 +++++++- TTS/bin/train_vocoder.py | 8 +++++++- TTS/server/server.py | 4 ++-- TTS/speaker_encoder/utils/training.py | 19 +++++++++++++------ 6 files changed, 34 insertions(+), 13 deletions(-) diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py index 06d5f388..40f60d5d 100644 --- a/TTS/bin/distribute.py +++ b/TTS/bin/distribute.py @@ -8,14 +8,14 @@ import time import torch -from TTS.trainer import TrainingArgs +from trainer import TrainerArgs def main(): """ Call train.py as a new process and pass command arguments """ - parser = TrainingArgs().init_argparse(arg_prefix="") + parser = TrainerArgs().init_argparse(arg_prefix="") parser.add_argument("--script", type=str, help="Target training script to distibute.") args, unargs = parser.parse_known_args() diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 8c364300..f19966ee 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -9,6 +9,8 @@ import traceback import torch from torch.utils.data import DataLoader +from trainer.torch import NoamLR + from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model @@ -19,7 +21,7 @@ from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.io import load_fsspec from TTS.utils.radam import RAdam -from TTS.utils.training import NoamLR, check_update +from TTS.utils.training import check_update torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 73063731..467685b2 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import os from trainer import Trainer, TrainerArgs @@ -7,10 +8,15 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model +@dataclass +class TrainTTSArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + def main(): """Run `tts` model training directly by a `config.json` file.""" # init trainer args - train_args = TrainerArgs() + train_args = TrainTTSArgs() parser = train_args.init_argparse(arg_prefix="") # override trainer args from comman-line args diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 6d4df610..c52fd962 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import os from trainer import Trainer, TrainerArgs @@ -8,10 +9,15 @@ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model +@dataclass +class TrainVocoderArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + def main(): """Run `tts` model training directly by a `config.json` file.""" # init trainer args - train_args = TrainerArgs() + train_args = TrainVocoderArgs() parser = train_args.init_argparse(arg_prefix="") # override trainer args from comman-line args diff --git a/TTS/server/server.py b/TTS/server/server.py index f2512582..aef507fd 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -88,7 +88,7 @@ if args.model_name is not None and not args.model_path: if args.vocoder_name is not None and not args.vocoder_path: vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) -# CASE3: set custome model paths +# CASE3: set custom model paths if args.model_path is not None: model_path = args.model_path config_path = args.config_path @@ -170,9 +170,9 @@ def tts(): text = request.args.get("text") speaker_idx = request.args.get("speaker_id", "") style_wav = request.args.get("style_wav", "") - style_wav = style_wav_uri_to_dict(style_wav) print(" > Model input: {}".format(text)) + print(" > Speaker Idx: {}".format(speaker_idx)) wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav) out = io.BytesIO() synthesizer.save_wav(wavs, out) diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py index b202ebcd..5c2de274 100644 --- a/TTS/speaker_encoder/utils/training.py +++ b/TTS/speaker_encoder/utils/training.py @@ -1,19 +1,26 @@ +from asyncio.log import logger +from dataclasses import dataclass, field import os from coqpit import Coqpit from TTS.config import load_config, register_config -from TTS.trainer import TrainingArgs +from trainer import TrainerArgs from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.io import copy_model_files -from TTS.utils.logging import init_dashboard_logger -from TTS.utils.logging.console_logger import ConsoleLogger +from trainer.logging import logger_factory +from trainer.logging.console_logger import ConsoleLogger from TTS.utils.trainer_utils import get_last_checkpoint +@dataclass +class TrainArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + def getarguments(): - train_config = TrainingArgs() + train_config = TrainArgs() parser = train_config.init_argparse(arg_prefix="") return parser @@ -75,13 +82,13 @@ def process_args(args, config=None): used_characters = parse_symbols() new_fields["characters"] = used_characters copy_model_files(config, experiment_path, new_fields) - dashboard_logger = init_dashboard_logger(config) + dashboard_logger = logger_factory(config, experiment_path) c_logger = ConsoleLogger() return config, experiment_path, audio_path, c_logger, dashboard_logger def init_arguments(): - train_config = TrainingArgs() + train_config = TrainArgs() parser = train_config.init_argparse(arg_prefix="") return parser