mirror of https://github.com/coqui-ai/TTS.git
save default model chars to the training config file
This commit is contained in:
parent
62a8eba3b2
commit
194f82de51
|
@ -3,17 +3,18 @@
|
||||||
"""Argument parser for training scripts."""
|
"""Argument parser for training scripts."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import re
|
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from TTS.utils.generic_utils import (
|
import json
|
||||||
create_experiment_folder, get_git_branch)
|
|
||||||
from TTS.utils.console_logger import ConsoleLogger
|
|
||||||
from TTS.utils.io import copy_model_files, load_config
|
|
||||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
|
||||||
|
|
||||||
from TTS.tts.utils.generic_utils import check_config_tts
|
from TTS.tts.utils.generic_utils import check_config_tts
|
||||||
|
from TTS.utils.console_logger import ConsoleLogger
|
||||||
|
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
|
||||||
|
from TTS.utils.io import (copy_model_files, load_config,
|
||||||
|
save_characters_to_config)
|
||||||
|
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||||
|
from TTS.tts.utils.text.symbols import parse_symbols
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments(argv):
|
def parse_arguments(argv):
|
||||||
|
@ -110,12 +111,9 @@ def get_last_checkpoint(path):
|
||||||
def process_args(args, model_type):
|
def process_args(args, model_type):
|
||||||
"""Process parsed comand line arguments.
|
"""Process parsed comand line arguments.
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
args (argparse.Namespace or dict like): Parsed input arguments.
|
||||||
args : argparse.Namespace or dict like
|
model_type (str): Model type used to check config parameters and setup the TensorBoard
|
||||||
Parsed input arguments.
|
|
||||||
model_type : str
|
|
||||||
Model type used to check config parameters and setup the TensorBoard
|
|
||||||
logger. One of:
|
logger. One of:
|
||||||
- tacotron
|
- tacotron
|
||||||
- glow_tts
|
- glow_tts
|
||||||
|
@ -124,24 +122,16 @@ def process_args(args, model_type):
|
||||||
- wavegrad
|
- wavegrad
|
||||||
- wavernn
|
- wavernn
|
||||||
|
|
||||||
Raises
|
Raises:
|
||||||
------
|
|
||||||
ValueError
|
ValueError
|
||||||
If `model_type` is not one of implemented choices.
|
If `model_type` is not one of implemented choices.
|
||||||
|
|
||||||
Returns
|
Returns:
|
||||||
-------
|
c (TTS.utils.io.AttrDict): Config paramaters.
|
||||||
c : TTS.utils.io.AttrDict
|
out_path (str): Path to save models and logging.
|
||||||
Config paramaters.
|
audio_path (str): Path to save generated test audios.
|
||||||
out_path : str
|
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does logging to the console.
|
||||||
Path to save models and logging.
|
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does the TensorBoard loggind.
|
||||||
audio_path : str
|
|
||||||
Path to save generated test audios.
|
|
||||||
c_logger : TTS.utils.console_logger.ConsoleLogger
|
|
||||||
Class that does logging to the console.
|
|
||||||
tb_logger : TTS.utils.tensorboard.TensorboardLogger
|
|
||||||
Class that does the TensorBoard loggind.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if args.continue_path != "":
|
if args.continue_path != "":
|
||||||
args.output_path = args.continue_path
|
args.output_path = args.continue_path
|
||||||
|
@ -156,7 +146,6 @@ def process_args(args, model_type):
|
||||||
|
|
||||||
# setup output paths and read configs
|
# setup output paths and read configs
|
||||||
c = load_config(args.config_path)
|
c = load_config(args.config_path)
|
||||||
|
|
||||||
if model_type in "tacotron glow_tts speedy_speech":
|
if model_type in "tacotron glow_tts speedy_speech":
|
||||||
model_class = "TTS"
|
model_class = "TTS"
|
||||||
elif model_type in "gan wavegrad wavernn":
|
elif model_type in "gan wavegrad wavernn":
|
||||||
|
@ -192,6 +181,12 @@ def process_args(args, model_type):
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
new_fields["restore_path"] = args.restore_path
|
new_fields["restore_path"] = args.restore_path
|
||||||
new_fields["github_branch"] = get_git_branch()
|
new_fields["github_branch"] = get_git_branch()
|
||||||
|
# if model characters are not set in the config file
|
||||||
|
# save the default set to the config file for future
|
||||||
|
# compatibility.
|
||||||
|
if model_class == 'TTS' and not 'characters' in c.keys():
|
||||||
|
used_characters = parse_symbols()
|
||||||
|
new_fields['characters'] = used_characters
|
||||||
copy_model_files(c, args.config_path,
|
copy_model_files(c, args.config_path,
|
||||||
out_path, new_fields)
|
out_path, new_fields)
|
||||||
os.chmod(audio_path, 0o775)
|
os.chmod(audio_path, 0o775)
|
||||||
|
|
|
@ -67,7 +67,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
new_line = '"{}":"{}",\n'.format(key, value)
|
new_line = '"{}":"{}",\n'.format(key, value)
|
||||||
else:
|
else:
|
||||||
new_line = '"{}":{},\n'.format(key, value)
|
new_line = '"{}":{},\n'.format(key, json.dumps(value, ensure_ascii=False))
|
||||||
config_lines.insert(1, new_line)
|
config_lines.insert(1, new_line)
|
||||||
config_out_file = open(copy_config_path, "w")
|
config_out_file = open(copy_config_path, "w")
|
||||||
config_out_file.writelines(config_lines)
|
config_out_file.writelines(config_lines)
|
||||||
|
|
Loading…
Reference in New Issue