Update imports for trainer

This commit is contained in:
Eren Gölge 2022-02-20 11:35:42 +01:00
parent c911729896
commit be3a03126a
6 changed files with 34 additions and 13 deletions

View File

@ -8,14 +8,14 @@ import time
import torch
from TTS.trainer import TrainingArgs
from trainer import TrainerArgs
def main():
"""
Call train.py as a new process and pass command arguments
"""
parser = TrainingArgs().init_argparse(arg_prefix="")
parser = TrainerArgs().init_argparse(arg_prefix="")
parser.add_argument("--script", type=str, help="Target training script to distibute.")
args, unargs = parser.parse_known_args()

View File

@ -9,6 +9,8 @@ import traceback
import torch
from torch.utils.data import DataLoader
from trainer.torch import NoamLR
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
@ -19,7 +21,7 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.io import load_fsspec
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, check_update
from TTS.utils.training import check_update
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

View File

@ -1,3 +1,4 @@
from dataclasses import dataclass, field
import os
from trainer import Trainer, TrainerArgs
@ -7,10 +8,15 @@ from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
@dataclass
class TrainTTSArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
def main():
"""Run `tts` model training directly by a `config.json` file."""
# init trainer args
train_args = TrainerArgs()
train_args = TrainTTSArgs()
parser = train_args.init_argparse(arg_prefix="")
# override trainer args from comman-line args

View File

@ -1,3 +1,4 @@
from dataclasses import dataclass, field
import os
from trainer import Trainer, TrainerArgs
@ -8,10 +9,15 @@ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model
@dataclass
class TrainVocoderArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
def main():
"""Run `tts` model training directly by a `config.json` file."""
# init trainer args
train_args = TrainerArgs()
train_args = TrainVocoderArgs()
parser = train_args.init_argparse(arg_prefix="")
# override trainer args from comman-line args

View File

@ -88,7 +88,7 @@ if args.model_name is not None and not args.model_path:
if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
# CASE3: set custome model paths
# CASE3: set custom model paths
if args.model_path is not None:
model_path = args.model_path
config_path = args.config_path
@ -170,9 +170,9 @@ def tts():
text = request.args.get("text")
speaker_idx = request.args.get("speaker_id", "")
style_wav = request.args.get("style_wav", "")
style_wav = style_wav_uri_to_dict(style_wav)
print(" > Model input: {}".format(text))
print(" > Speaker Idx: {}".format(speaker_idx))
wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)

View File

@ -1,19 +1,26 @@
from asyncio.log import logger
from dataclasses import dataclass, field
import os
from coqpit import Coqpit
from TTS.config import load_config, register_config
from TTS.trainer import TrainingArgs
from trainer import TrainerArgs
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files
from TTS.utils.logging import init_dashboard_logger
from TTS.utils.logging.console_logger import ConsoleLogger
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.utils.trainer_utils import get_last_checkpoint
@dataclass
class TrainArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
def getarguments():
train_config = TrainingArgs()
train_config = TrainArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser
@ -75,13 +82,13 @@ def process_args(args, config=None):
used_characters = parse_symbols()
new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields)
dashboard_logger = init_dashboard_logger(config)
dashboard_logger = logger_factory(config, experiment_path)
c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, dashboard_logger
def init_arguments():
train_config = TrainingArgs()
train_config = TrainArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser