coqpit refactoring

This commit is contained in:
Eren Gölge 2021-05-05 02:31:12 +02:00
parent eaa130e813
commit 647163397d
5 changed files with 17 additions and 28 deletions

View File

@ -23,8 +23,8 @@ def text_to_seqvec(text, CONFIG):
text_cleaner, text_cleaner,
CONFIG.phoneme_language, CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars, CONFIG.enable_eos_bos_chars,
tp=CONFIG.characters if "characters" in CONFIG.keys() else None, tp=CONFIG.characters,
add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False, add_blank=CONFIG.add_blank,
), ),
dtype=np.int32, dtype=np.int32,
) )
@ -33,8 +33,8 @@ def text_to_seqvec(text, CONFIG):
text_to_sequence( text_to_sequence(
text, text,
text_cleaner, text_cleaner,
tp=CONFIG.characters if "characters" in CONFIG.keys() else None, tp=CONFIG.characters,
add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False, add_blank=CONFIG.add_blank,
), ),
dtype=np.int32, dtype=np.int32,
) )

View File

@ -2,12 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""Argument parser for training scripts.""" """Argument parser for training scripts."""
import torch
import argparse import argparse
import glob import glob
import json
import os import os
import re import re
import sys
from TTS.tts.utils.text.symbols import parse_symbols from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.console_logger import ConsoleLogger from TTS.utils.console_logger import ConsoleLogger
@ -146,10 +145,9 @@ def process_args(args):
config.parse_args(coqpit_overrides) config.parse_args(coqpit_overrides)
if config.mixed_precision: if config.mixed_precision:
print(" > Mixed precision mode is ON") print(" > Mixed precision mode is ON")
if not os.path.exists(config.output_path): experiment_path = args.continue_path
if not experiment_path:
experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug) experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
else:
experiment_path = config.output_path
audio_path = os.path.join(experiment_path, "test_audios") audio_path = os.path.join(experiment_path, "test_audios")
# setup rank 0 process in distributed training # setup rank 0 process in distributed training
if args.rank == 0: if args.rank == 0:
@ -164,12 +162,12 @@ def process_args(args):
if config.has("characters_config"): if config.has("characters_config"):
used_characters = parse_symbols() used_characters = parse_symbols()
new_fields["characters"] = used_characters new_fields["characters"] = used_characters
copy_model_files(config, args.config_path, experiment_path, new_fields) copy_model_files(config, experiment_path, new_fields)
os.chmod(audio_path, 0o775) os.chmod(audio_path, 0o775)
os.chmod(experiment_path, 0o775) os.chmod(experiment_path, 0o775)
tb_logger = TensorboardLogger(experiment_path, model_name=config.model) tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
# write model desc to tensorboard # write model desc to tensorboard
tb_logger.tb_add_text("model-description", config["run_description"], 0) tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
c_logger = ConsoleLogger() c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, tb_logger return config, experiment_path, audio_path, c_logger, tb_logger

View File

@ -21,6 +21,7 @@ class AudioProcessor(object):
sample_rate (int, optional): target audio sampling rate. Defaults to None. sample_rate (int, optional): target audio sampling rate. Defaults to None.
resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
num_mels (int, optional): number of melspectrogram dimensions. Defaults to None. num_mels (int, optional): number of melspectrogram dimensions. Defaults to None.
log_func (int, optional): log exponent used for converting spectrogram amplitude to DB.
min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None. min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None.
frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None. frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None.
frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None. frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None.

View File

@ -111,7 +111,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
# 2. filter out different size layers # 2. filter out different size layers
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
# 3. skip reinit layers # 3. skip reinit layers
if c.reinit_layers is not None: if c.has('reinit_layers') and c.reinit_layers is not None:
for reinit_layer_name in c.reinit_layers: for reinit_layer_name in c.reinit_layers:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict # 4. overwrite entries in the existing state dict

View File

@ -58,35 +58,25 @@ def load_config(config_path: str) -> None:
return config return config
def copy_model_files(c, config_file, out_path, new_fields): def copy_model_files(config, out_path, new_fields):
"""Copy config.json and other model files to training folder and add """Copy config.json and other model files to training folder and add
new fields. new fields.
Args: Args:
c (dict): model config from config.json. config (Coqpit): Coqpit config defining the training run.
config_file (str): path to config file.
out_path (str): output path to copy the file. out_path (str): output path to copy the file.
new_fields (dict): new fields to be added or edited new_fields (dict): new fields to be added or edited
in the config file. in the config file.
""" """
# copy config.json
copy_config_path = os.path.join(out_path, "config.json") copy_config_path = os.path.join(out_path, "config.json")
config_lines = open(config_file, "r", encoding="utf-8").readlines()
# add extra information fields # add extra information fields
for key, value in new_fields.items(): config.update(new_fields, allow_new=True)
if isinstance(value, str): config.save_json(copy_config_path)
new_line = '"{}":"{}",\n'.format(key, value)
else:
new_line = '"{}":{},\n'.format(key, json.dumps(value, ensure_ascii=False))
config_lines.insert(1, new_line)
config_out_file = open(copy_config_path, "w", encoding="utf-8")
config_out_file.writelines(config_lines)
config_out_file.close()
# copy model stats file if available # copy model stats file if available
if c.audio["stats_path"] is not None: if config.audio.stats_path is not None:
copy_stats_path = os.path.join(out_path, "scale_stats.npy") copy_stats_path = os.path.join(out_path, "scale_stats.npy")
if not os.path.exists(copy_stats_path): if not os.path.exists(copy_stats_path):
copyfile( copyfile(
c.audio["stats_path"], config.audio.stats_path,
copy_stats_path, copy_stats_path,
) )