mirror of https://github.com/coqui-ai/TTS.git
coqpit refactoring
This commit is contained in:
parent
eaa130e813
commit
647163397d
|
@ -23,8 +23,8 @@ def text_to_seqvec(text, CONFIG):
|
||||||
text_cleaner,
|
text_cleaner,
|
||||||
CONFIG.phoneme_language,
|
CONFIG.phoneme_language,
|
||||||
CONFIG.enable_eos_bos_chars,
|
CONFIG.enable_eos_bos_chars,
|
||||||
tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
|
tp=CONFIG.characters,
|
||||||
add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False,
|
add_blank=CONFIG.add_blank,
|
||||||
),
|
),
|
||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
)
|
)
|
||||||
|
@ -33,8 +33,8 @@ def text_to_seqvec(text, CONFIG):
|
||||||
text_to_sequence(
|
text_to_sequence(
|
||||||
text,
|
text,
|
||||||
text_cleaner,
|
text_cleaner,
|
||||||
tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
|
tp=CONFIG.characters,
|
||||||
add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False,
|
add_blank=CONFIG.add_blank,
|
||||||
),
|
),
|
||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,12 +2,11 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""Argument parser for training scripts."""
|
"""Argument parser for training scripts."""
|
||||||
|
|
||||||
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
|
|
||||||
from TTS.tts.utils.text.symbols import parse_symbols
|
from TTS.tts.utils.text.symbols import parse_symbols
|
||||||
from TTS.utils.console_logger import ConsoleLogger
|
from TTS.utils.console_logger import ConsoleLogger
|
||||||
|
@ -146,10 +145,9 @@ def process_args(args):
|
||||||
config.parse_args(coqpit_overrides)
|
config.parse_args(coqpit_overrides)
|
||||||
if config.mixed_precision:
|
if config.mixed_precision:
|
||||||
print(" > Mixed precision mode is ON")
|
print(" > Mixed precision mode is ON")
|
||||||
if not os.path.exists(config.output_path):
|
experiment_path = args.continue_path
|
||||||
|
if not experiment_path:
|
||||||
experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
|
experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
|
||||||
else:
|
|
||||||
experiment_path = config.output_path
|
|
||||||
audio_path = os.path.join(experiment_path, "test_audios")
|
audio_path = os.path.join(experiment_path, "test_audios")
|
||||||
# setup rank 0 process in distributed training
|
# setup rank 0 process in distributed training
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
|
@ -164,12 +162,12 @@ def process_args(args):
|
||||||
if config.has("characters_config"):
|
if config.has("characters_config"):
|
||||||
used_characters = parse_symbols()
|
used_characters = parse_symbols()
|
||||||
new_fields["characters"] = used_characters
|
new_fields["characters"] = used_characters
|
||||||
copy_model_files(config, args.config_path, experiment_path, new_fields)
|
copy_model_files(config, experiment_path, new_fields)
|
||||||
os.chmod(audio_path, 0o775)
|
os.chmod(audio_path, 0o775)
|
||||||
os.chmod(experiment_path, 0o775)
|
os.chmod(experiment_path, 0o775)
|
||||||
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
|
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
|
||||||
# write model desc to tensorboard
|
# write model desc to tensorboard
|
||||||
tb_logger.tb_add_text("model-description", config["run_description"], 0)
|
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
||||||
c_logger = ConsoleLogger()
|
c_logger = ConsoleLogger()
|
||||||
return config, experiment_path, audio_path, c_logger, tb_logger
|
return config, experiment_path, audio_path, c_logger, tb_logger
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ class AudioProcessor(object):
|
||||||
sample_rate (int, optional): target audio sampling rate. Defaults to None.
|
sample_rate (int, optional): target audio sampling rate. Defaults to None.
|
||||||
resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
|
resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
|
||||||
num_mels (int, optional): number of melspectrogram dimensions. Defaults to None.
|
num_mels (int, optional): number of melspectrogram dimensions. Defaults to None.
|
||||||
|
log_func (int, optional): log exponent used for converting spectrogram amplitude to DB.
|
||||||
min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None.
|
min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None.
|
||||||
frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None.
|
frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None.
|
||||||
frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None.
|
frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None.
|
||||||
|
|
|
@ -111,7 +111,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
|
||||||
# 2. filter out different size layers
|
# 2. filter out different size layers
|
||||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
|
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
|
||||||
# 3. skip reinit layers
|
# 3. skip reinit layers
|
||||||
if c.reinit_layers is not None:
|
if c.has('reinit_layers') and c.reinit_layers is not None:
|
||||||
for reinit_layer_name in c.reinit_layers:
|
for reinit_layer_name in c.reinit_layers:
|
||||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
|
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
|
||||||
# 4. overwrite entries in the existing state dict
|
# 4. overwrite entries in the existing state dict
|
||||||
|
|
|
@ -58,35 +58,25 @@ def load_config(config_path: str) -> None:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def copy_model_files(c, config_file, out_path, new_fields):
|
def copy_model_files(config, out_path, new_fields):
|
||||||
"""Copy config.json and other model files to training folder and add
|
"""Copy config.json and other model files to training folder and add
|
||||||
new fields.
|
new fields.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
c (dict): model config from config.json.
|
config (Coqpit): Coqpit config defining the training run.
|
||||||
config_file (str): path to config file.
|
|
||||||
out_path (str): output path to copy the file.
|
out_path (str): output path to copy the file.
|
||||||
new_fields (dict): new fields to be added or edited
|
new_fields (dict): new fields to be added or edited
|
||||||
in the config file.
|
in the config file.
|
||||||
"""
|
"""
|
||||||
# copy config.json
|
|
||||||
copy_config_path = os.path.join(out_path, "config.json")
|
copy_config_path = os.path.join(out_path, "config.json")
|
||||||
config_lines = open(config_file, "r", encoding="utf-8").readlines()
|
|
||||||
# add extra information fields
|
# add extra information fields
|
||||||
for key, value in new_fields.items():
|
config.update(new_fields, allow_new=True)
|
||||||
if isinstance(value, str):
|
config.save_json(copy_config_path)
|
||||||
new_line = '"{}":"{}",\n'.format(key, value)
|
|
||||||
else:
|
|
||||||
new_line = '"{}":{},\n'.format(key, json.dumps(value, ensure_ascii=False))
|
|
||||||
config_lines.insert(1, new_line)
|
|
||||||
config_out_file = open(copy_config_path, "w", encoding="utf-8")
|
|
||||||
config_out_file.writelines(config_lines)
|
|
||||||
config_out_file.close()
|
|
||||||
# copy model stats file if available
|
# copy model stats file if available
|
||||||
if c.audio["stats_path"] is not None:
|
if config.audio.stats_path is not None:
|
||||||
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
|
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
|
||||||
if not os.path.exists(copy_stats_path):
|
if not os.path.exists(copy_stats_path):
|
||||||
copyfile(
|
copyfile(
|
||||||
c.audio["stats_path"],
|
config.audio.stats_path,
|
||||||
copy_stats_path,
|
copy_stats_path,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue