mirror of https://github.com/coqui-ai/TTS.git
Update imports for trainer
This commit is contained in:
parent
c911729896
commit
be3a03126a
|
@ -8,14 +8,14 @@ import time
|
|||
|
||||
import torch
|
||||
|
||||
from TTS.trainer import TrainingArgs
|
||||
from trainer import TrainerArgs
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Call train.py as a new process and pass command arguments
|
||||
"""
|
||||
parser = TrainingArgs().init_argparse(arg_prefix="")
|
||||
parser = TrainerArgs().init_argparse(arg_prefix="")
|
||||
parser.add_argument("--script", type=str, help="Target training script to distibute.")
|
||||
args, unargs = parser.parse_known_args()
|
||||
|
||||
|
|
|
@ -9,6 +9,8 @@ import traceback
|
|||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from trainer.torch import NoamLR
|
||||
|
||||
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
|
||||
|
@ -19,7 +21,7 @@ from TTS.utils.audio import AudioProcessor
|
|||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.training import NoamLR, check_update
|
||||
from TTS.utils.training import check_update
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from dataclasses import dataclass, field
|
||||
import os
|
||||
|
||||
from trainer import Trainer, TrainerArgs
|
||||
|
@ -7,10 +8,15 @@ from TTS.tts.datasets import load_tts_samples
|
|||
from TTS.tts.models import setup_model
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainTTSArgs(TrainerArgs):
|
||||
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
|
||||
|
||||
|
||||
def main():
|
||||
"""Run `tts` model training directly by a `config.json` file."""
|
||||
# init trainer args
|
||||
train_args = TrainerArgs()
|
||||
train_args = TrainTTSArgs()
|
||||
parser = train_args.init_argparse(arg_prefix="")
|
||||
|
||||
# override trainer args from comman-line args
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from dataclasses import dataclass, field
|
||||
import os
|
||||
|
||||
from trainer import Trainer, TrainerArgs
|
||||
|
@ -8,10 +9,15 @@ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
|||
from TTS.vocoder.models import setup_model
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainVocoderArgs(TrainerArgs):
|
||||
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
|
||||
|
||||
|
||||
def main():
|
||||
"""Run `tts` model training directly by a `config.json` file."""
|
||||
# init trainer args
|
||||
train_args = TrainerArgs()
|
||||
train_args = TrainVocoderArgs()
|
||||
parser = train_args.init_argparse(arg_prefix="")
|
||||
|
||||
# override trainer args from comman-line args
|
||||
|
|
|
@ -88,7 +88,7 @@ if args.model_name is not None and not args.model_path:
|
|||
if args.vocoder_name is not None and not args.vocoder_path:
|
||||
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
|
||||
|
||||
# CASE3: set custome model paths
|
||||
# CASE3: set custom model paths
|
||||
if args.model_path is not None:
|
||||
model_path = args.model_path
|
||||
config_path = args.config_path
|
||||
|
@ -170,9 +170,9 @@ def tts():
|
|||
text = request.args.get("text")
|
||||
speaker_idx = request.args.get("speaker_id", "")
|
||||
style_wav = request.args.get("style_wav", "")
|
||||
|
||||
style_wav = style_wav_uri_to_dict(style_wav)
|
||||
print(" > Model input: {}".format(text))
|
||||
print(" > Speaker Idx: {}".format(speaker_idx))
|
||||
wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
|
||||
out = io.BytesIO()
|
||||
synthesizer.save_wav(wavs, out)
|
||||
|
|
|
@ -1,19 +1,26 @@
|
|||
from asyncio.log import logger
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.trainer import TrainingArgs
|
||||
from trainer import TrainerArgs
|
||||
from TTS.tts.utils.text.characters import parse_symbols
|
||||
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
||||
from TTS.utils.io import copy_model_files
|
||||
from TTS.utils.logging import init_dashboard_logger
|
||||
from TTS.utils.logging.console_logger import ConsoleLogger
|
||||
from trainer.logging import logger_factory
|
||||
from trainer.logging.console_logger import ConsoleLogger
|
||||
from TTS.utils.trainer_utils import get_last_checkpoint
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainArgs(TrainerArgs):
|
||||
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
|
||||
|
||||
|
||||
def getarguments():
|
||||
train_config = TrainingArgs()
|
||||
train_config = TrainArgs()
|
||||
parser = train_config.init_argparse(arg_prefix="")
|
||||
return parser
|
||||
|
||||
|
@ -75,13 +82,13 @@ def process_args(args, config=None):
|
|||
used_characters = parse_symbols()
|
||||
new_fields["characters"] = used_characters
|
||||
copy_model_files(config, experiment_path, new_fields)
|
||||
dashboard_logger = init_dashboard_logger(config)
|
||||
dashboard_logger = logger_factory(config, experiment_path)
|
||||
c_logger = ConsoleLogger()
|
||||
return config, experiment_path, audio_path, c_logger, dashboard_logger
|
||||
|
||||
|
||||
def init_arguments():
|
||||
train_config = TrainingArgs()
|
||||
train_config = TrainArgs()
|
||||
parser = train_config.init_argparse(arg_prefix="")
|
||||
return parser
|
||||
|
||||
|
|
Loading…
Reference in New Issue