From 647163397db6f2c94bdd4a4046f7ccfa0154a4bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 5 May 2021 02:31:12 +0200 Subject: [PATCH] coqpit refactoring --- TTS/tts/utils/synthesis.py | 8 ++++---- TTS/utils/arguments.py | 12 +++++------- TTS/utils/audio.py | 1 + TTS/utils/generic_utils.py | 2 +- TTS/utils/io.py | 22 ++++++---------------- 5 files changed, 17 insertions(+), 28 deletions(-) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 281ef55d..405cf2dc 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -23,8 +23,8 @@ def text_to_seqvec(text, CONFIG): text_cleaner, CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters if "characters" in CONFIG.keys() else None, - add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False, + tp=CONFIG.characters, + add_blank=CONFIG.add_blank, ), dtype=np.int32, ) @@ -33,8 +33,8 @@ def text_to_seqvec(text, CONFIG): text_to_sequence( text, text_cleaner, - tp=CONFIG.characters if "characters" in CONFIG.keys() else None, - add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False, + tp=CONFIG.characters, + add_blank=CONFIG.add_blank, ), dtype=np.int32, ) diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py index 344a8ba6..85e89191 100644 --- a/TTS/utils/arguments.py +++ b/TTS/utils/arguments.py @@ -2,12 +2,11 @@ # -*- coding: utf-8 -*- """Argument parser for training scripts.""" +import torch import argparse import glob -import json import os import re -import sys from TTS.tts.utils.text.symbols import parse_symbols from TTS.utils.console_logger import ConsoleLogger @@ -146,10 +145,9 @@ def process_args(args): config.parse_args(coqpit_overrides) if config.mixed_precision: print(" > Mixed precision mode is ON") - if not os.path.exists(config.output_path): + experiment_path = args.continue_path + if not experiment_path: experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug) - else: - experiment_path = config.output_path audio_path = os.path.join(experiment_path, "test_audios") # setup rank 0 process in distributed training if args.rank == 0: @@ -164,12 +162,12 @@ def process_args(args): if config.has("characters_config"): used_characters = parse_symbols() new_fields["characters"] = used_characters - copy_model_files(config, args.config_path, experiment_path, new_fields) + copy_model_files(config, experiment_path, new_fields) os.chmod(audio_path, 0o775) os.chmod(experiment_path, 0o775) tb_logger = TensorboardLogger(experiment_path, model_name=config.model) # write model desc to tensorboard - tb_logger.tb_add_text("model-description", config["run_description"], 0) + tb_logger.tb_add_text("model-config", f"
{config.to_json()}
", 0) c_logger = ConsoleLogger() return config, experiment_path, audio_path, c_logger, tb_logger diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index cb1341f4..222b4c74 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -21,6 +21,7 @@ class AudioProcessor(object): sample_rate (int, optional): target audio sampling rate. Defaults to None. resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. num_mels (int, optional): number of melspectrogram dimensions. Defaults to None. + log_func (int, optional): log exponent used for converting spectrogram aplitude to DB. min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None. frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None. frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None. diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 274fb634..92244b8b 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -111,7 +111,7 @@ def set_init_dict(model_dict, checkpoint_state, c): # 2. filter out different size layers pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} # 3. skip reinit layers - if c.reinit_layers is not None: + if c.has('reinit_layers') and c.reinit_layers is not None: for reinit_layer_name in c.reinit_layers: pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} # 4. overwrite entries in the existing state dict diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 0a352d0d..6d233d24 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -58,35 +58,25 @@ def load_config(config_path: str) -> None: return config -def copy_model_files(c, config_file, out_path, new_fields): +def copy_model_files(config, out_path, new_fields): """Copy config.json and other model files to training folder and add new fields. Args: - c (dict): model config from config.json. - config_file (str): path to config file. + config (Coqpit): Coqpit config defining the training run. out_path (str): output path to copy the file. new_fields (dict): new fileds to be added or edited in the config file. """ - # copy config.json copy_config_path = os.path.join(out_path, "config.json") - config_lines = open(config_file, "r", encoding="utf-8").readlines() # add extra information fields - for key, value in new_fields.items(): - if isinstance(value, str): - new_line = '"{}":"{}",\n'.format(key, value) - else: - new_line = '"{}":{},\n'.format(key, json.dumps(value, ensure_ascii=False)) - config_lines.insert(1, new_line) - config_out_file = open(copy_config_path, "w", encoding="utf-8") - config_out_file.writelines(config_lines) - config_out_file.close() + config.update(new_fields, allow_new=True) + config.save_json(copy_config_path) # copy model stats file if available - if c.audio["stats_path"] is not None: + if config.audio.stats_path is not None: copy_stats_path = os.path.join(out_path, "scale_stats.npy") if not os.path.exists(copy_stats_path): copyfile( - c.audio["stats_path"], + config.audio.stats_path, copy_stats_path, )