mirror of https://github.com/coqui-ai/TTS.git
Update imports for trainer
This commit is contained in:
parent
c911729896
commit
be3a03126a
|
@ -8,14 +8,14 @@ import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from TTS.trainer import TrainingArgs
|
from trainer import TrainerArgs
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""
|
"""
|
||||||
Call train.py as a new process and pass command arguments
|
Call train.py as a new process and pass command arguments
|
||||||
"""
|
"""
|
||||||
parser = TrainingArgs().init_argparse(arg_prefix="")
|
parser = TrainerArgs().init_argparse(arg_prefix="")
|
||||||
parser.add_argument("--script", type=str, help="Target training script to distibute.")
|
parser.add_argument("--script", type=str, help="Target training script to distibute.")
|
||||||
args, unargs = parser.parse_known_args()
|
args, unargs = parser.parse_known_args()
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,8 @@ import traceback
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from trainer.torch import NoamLR
|
||||||
|
|
||||||
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
||||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||||
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
|
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
|
||||||
|
@ -19,7 +21,7 @@ from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
|
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
from TTS.utils.radam import RAdam
|
from TTS.utils.radam import RAdam
|
||||||
from TTS.utils.training import NoamLR, check_update
|
from TTS.utils.training import check_update
|
||||||
|
|
||||||
torch.backends.cudnn.enabled = True
|
torch.backends.cudnn.enabled = True
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from trainer import Trainer, TrainerArgs
|
from trainer import Trainer, TrainerArgs
|
||||||
|
@ -7,10 +8,15 @@ from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainTTSArgs(TrainerArgs):
|
||||||
|
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Run `tts` model training directly by a `config.json` file."""
|
"""Run `tts` model training directly by a `config.json` file."""
|
||||||
# init trainer args
|
# init trainer args
|
||||||
train_args = TrainerArgs()
|
train_args = TrainTTSArgs()
|
||||||
parser = train_args.init_argparse(arg_prefix="")
|
parser = train_args.init_argparse(arg_prefix="")
|
||||||
|
|
||||||
# override trainer args from comman-line args
|
# override trainer args from comman-line args
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from trainer import Trainer, TrainerArgs
|
from trainer import Trainer, TrainerArgs
|
||||||
|
@ -8,10 +9,15 @@ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||||
from TTS.vocoder.models import setup_model
|
from TTS.vocoder.models import setup_model
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainVocoderArgs(TrainerArgs):
|
||||||
|
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Run `tts` model training directly by a `config.json` file."""
|
"""Run `tts` model training directly by a `config.json` file."""
|
||||||
# init trainer args
|
# init trainer args
|
||||||
train_args = TrainerArgs()
|
train_args = TrainVocoderArgs()
|
||||||
parser = train_args.init_argparse(arg_prefix="")
|
parser = train_args.init_argparse(arg_prefix="")
|
||||||
|
|
||||||
# override trainer args from comman-line args
|
# override trainer args from comman-line args
|
||||||
|
|
|
@ -88,7 +88,7 @@ if args.model_name is not None and not args.model_path:
|
||||||
if args.vocoder_name is not None and not args.vocoder_path:
|
if args.vocoder_name is not None and not args.vocoder_path:
|
||||||
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
|
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
|
||||||
|
|
||||||
# CASE3: set custome model paths
|
# CASE3: set custom model paths
|
||||||
if args.model_path is not None:
|
if args.model_path is not None:
|
||||||
model_path = args.model_path
|
model_path = args.model_path
|
||||||
config_path = args.config_path
|
config_path = args.config_path
|
||||||
|
@ -170,9 +170,9 @@ def tts():
|
||||||
text = request.args.get("text")
|
text = request.args.get("text")
|
||||||
speaker_idx = request.args.get("speaker_id", "")
|
speaker_idx = request.args.get("speaker_id", "")
|
||||||
style_wav = request.args.get("style_wav", "")
|
style_wav = request.args.get("style_wav", "")
|
||||||
|
|
||||||
style_wav = style_wav_uri_to_dict(style_wav)
|
style_wav = style_wav_uri_to_dict(style_wav)
|
||||||
print(" > Model input: {}".format(text))
|
print(" > Model input: {}".format(text))
|
||||||
|
print(" > Speaker Idx: {}".format(speaker_idx))
|
||||||
wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
|
wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
|
||||||
out = io.BytesIO()
|
out = io.BytesIO()
|
||||||
synthesizer.save_wav(wavs, out)
|
synthesizer.save_wav(wavs, out)
|
||||||
|
|
|
@ -1,19 +1,26 @@
|
||||||
|
from asyncio.log import logger
|
||||||
|
from dataclasses import dataclass, field
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
|
|
||||||
from TTS.config import load_config, register_config
|
from TTS.config import load_config, register_config
|
||||||
from TTS.trainer import TrainingArgs
|
from trainer import TrainerArgs
|
||||||
from TTS.tts.utils.text.characters import parse_symbols
|
from TTS.tts.utils.text.characters import parse_symbols
|
||||||
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
||||||
from TTS.utils.io import copy_model_files
|
from TTS.utils.io import copy_model_files
|
||||||
from TTS.utils.logging import init_dashboard_logger
|
from trainer.logging import logger_factory
|
||||||
from TTS.utils.logging.console_logger import ConsoleLogger
|
from trainer.logging.console_logger import ConsoleLogger
|
||||||
from TTS.utils.trainer_utils import get_last_checkpoint
|
from TTS.utils.trainer_utils import get_last_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainArgs(TrainerArgs):
|
||||||
|
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
|
||||||
|
|
||||||
|
|
||||||
def getarguments():
|
def getarguments():
|
||||||
train_config = TrainingArgs()
|
train_config = TrainArgs()
|
||||||
parser = train_config.init_argparse(arg_prefix="")
|
parser = train_config.init_argparse(arg_prefix="")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
@ -75,13 +82,13 @@ def process_args(args, config=None):
|
||||||
used_characters = parse_symbols()
|
used_characters = parse_symbols()
|
||||||
new_fields["characters"] = used_characters
|
new_fields["characters"] = used_characters
|
||||||
copy_model_files(config, experiment_path, new_fields)
|
copy_model_files(config, experiment_path, new_fields)
|
||||||
dashboard_logger = init_dashboard_logger(config)
|
dashboard_logger = logger_factory(config, experiment_path)
|
||||||
c_logger = ConsoleLogger()
|
c_logger = ConsoleLogger()
|
||||||
return config, experiment_path, audio_path, c_logger, dashboard_logger
|
return config, experiment_path, audio_path, c_logger, dashboard_logger
|
||||||
|
|
||||||
|
|
||||||
def init_arguments():
|
def init_arguments():
|
||||||
train_config = TrainingArgs()
|
train_config = TrainArgs()
|
||||||
parser = train_config.init_argparse(arg_prefix="")
|
parser = train_config.init_argparse(arg_prefix="")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue