save default model chars to the training config file

This commit is contained in:
Eren Gölge 2021-02-12 12:06:46 +00:00 committed by Eren Gölge
parent 62a8eba3b2
commit 194f82de51
2 changed files with 34 additions and 39 deletions

View File

@ -3,17 +3,18 @@
"""Argument parser for training scripts.""" """Argument parser for training scripts."""
import argparse import argparse
import re
import glob import glob
import os import os
import re
from TTS.utils.generic_utils import ( import json
create_experiment_folder, get_git_branch)
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.tts.utils.generic_utils import check_config_tts from TTS.tts.utils.generic_utils import check_config_tts
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
from TTS.utils.io import (copy_model_files, load_config,
save_characters_to_config)
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.tts.utils.text.symbols import parse_symbols
def parse_arguments(argv): def parse_arguments(argv):
@ -110,12 +111,9 @@ def get_last_checkpoint(path):
def process_args(args, model_type): def process_args(args, model_type):
"""Process parsed comand line arguments. """Process parsed comand line arguments.
Parameters Args:
---------- args (argparse.Namespace or dict like): Parsed input arguments.
args : argparse.Namespace or dict like model_type (str): Model type used to check config parameters and setup the TensorBoard
Parsed input arguments.
model_type : str
Model type used to check config parameters and setup the TensorBoard
logger. One of: logger. One of:
- tacotron - tacotron
- glow_tts - glow_tts
@ -124,24 +122,16 @@ def process_args(args, model_type):
- wavegrad - wavegrad
- wavernn - wavernn
Raises Raises:
------
ValueError ValueError
If `model_type` is not one of implemented choices. If `model_type` is not one of implemented choices.
Returns Returns:
------- c (TTS.utils.io.AttrDict): Config paramaters.
c : TTS.utils.io.AttrDict out_path (str): Path to save models and logging.
Config paramaters. audio_path (str): Path to save generated test audios.
out_path : str c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does logging to the console.
Path to save models and logging. tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does the TensorBoard loggind.
audio_path : str
Path to save generated test audios.
c_logger : TTS.utils.console_logger.ConsoleLogger
Class that does logging to the console.
tb_logger : TTS.utils.tensorboard.TensorboardLogger
Class that does the TensorBoard loggind.
""" """
if args.continue_path != "": if args.continue_path != "":
args.output_path = args.continue_path args.output_path = args.continue_path
@ -156,7 +146,6 @@ def process_args(args, model_type):
# setup output paths and read configs # setup output paths and read configs
c = load_config(args.config_path) c = load_config(args.config_path)
if model_type in "tacotron glow_tts speedy_speech": if model_type in "tacotron glow_tts speedy_speech":
model_class = "TTS" model_class = "TTS"
elif model_type in "gan wavegrad wavernn": elif model_type in "gan wavegrad wavernn":
@ -192,6 +181,12 @@ def process_args(args, model_type):
if args.restore_path: if args.restore_path:
new_fields["restore_path"] = args.restore_path new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch() new_fields["github_branch"] = get_git_branch()
# if model characters are not set in the config file
# save the default set to the config file for future
# compatibility.
if model_class == 'TTS' and not 'characters' in c.keys():
used_characters = parse_symbols()
new_fields['characters'] = used_characters
copy_model_files(c, args.config_path, copy_model_files(c, args.config_path,
out_path, new_fields) out_path, new_fields)
os.chmod(audio_path, 0o775) os.chmod(audio_path, 0o775)

View File

@ -67,7 +67,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
if isinstance(value, str): if isinstance(value, str):
new_line = '"{}":"{}",\n'.format(key, value) new_line = '"{}":"{}",\n'.format(key, value)
else: else:
new_line = '"{}":{},\n'.format(key, value) new_line = '"{}":{},\n'.format(key, json.dumps(value, ensure_ascii=False))
config_lines.insert(1, new_line) config_lines.insert(1, new_line)
config_out_file = open(copy_config_path, "w") config_out_file = open(copy_config_path, "w")
config_out_file.writelines(config_lines) config_out_file.writelines(config_lines)