config refactor #4 WIP

Eren Gölge 2021-04-01 18:22:24 +02:00
parent 97bd5f9734
commit dc50f5f0b0
6 changed files with 229 additions and 267 deletions

View File

@@ -8,6 +8,7 @@ import os
 import numpy as np
 from tqdm import tqdm
+from TTS.utils.config_manager import ConfigManager
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_config
@@ -15,26 +16,33 @@ from TTS.utils.io import load_config
 def main():
     """Run preprocessing process."""
-    parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
-    parser.add_argument(
-        "--config_path", type=str, required=True, help="TTS config file path to define audio processin parameters."
-    )
-    parser.add_argument("--out_path", type=str, required=True, help="save path (directory and filename).")
+    CONFIG = ConfigManager()
+    parser = argparse.ArgumentParser(
+        description="Compute mean and variance of spectrogtram features.")
+    parser.add_argument("config_path", type=str,
+                        help="TTS config file path to define audio processin parameters.")
+    parser.add_argument("out_path", type=str,
+                        help="save path (directory and filename).")
+    parser.add_argument("--data_path", type=str, required=False,
+                        help="folder including the target set of wavs overriding dataset config.")
+    parser = CONFIG.init_argparse(parser)
     args = parser.parse_args()
+    CONFIG.parse_argparse(args)
     # load config
-    CONFIG = load_config(args.config_path)
-    CONFIG.audio["signal_norm"] = False  # do not apply earlier normalization
-    CONFIG.audio["stats_path"] = None  # discard pre-defined stats
+    CONFIG.load_config(args.config_path)
+    CONFIG.audio_config.signal_norm = False  # do not apply earlier normalization
+    CONFIG.audio_config.stats_path = None  # discard pre-defined stats
     # load audio processor
-    ap = AudioProcessor(**CONFIG.audio)
+    ap = AudioProcessor(**CONFIG.audio_config.to_dict())
     # load the meta data of target dataset
-    if "data_path" in CONFIG.keys():
-        dataset_items = glob.glob(os.path.join(CONFIG.data_path, "**", "*.wav"), recursive=True)
+    if args.data_path:
+        dataset_items = glob.glob(os.path.join(args.data_path, '**', '*.wav'), recursive=True)
     else:
-        dataset_items = load_meta_data(CONFIG.datasets)[0]  # take only train data
+        dataset_items = load_meta_data(CONFIG.dataset_config)[0]  # take only train data
     print(f" > There are {len(dataset_items)} files.")
     mel_sum = 0
@@ -73,14 +81,15 @@ def main():
     print(f" > Avg lienar spec scale: {linear_scale.mean()}")
     # set default config values for mean-var scaling
-    CONFIG.audio["stats_path"] = output_file_path
-    CONFIG.audio["signal_norm"] = True
+    CONFIG.audio_config.stats_path = output_file_path
+    CONFIG.audio_config.signal_norm = True
     # remove redundant values
-    del CONFIG.audio["max_norm"]
-    del CONFIG.audio["min_level_db"]
-    del CONFIG.audio["symmetric_norm"]
-    del CONFIG.audio["clip_norm"]
-    stats["audio_config"] = CONFIG.audio
+    del CONFIG.audio_config.max_norm
+    del CONFIG.audio_config.min_level_db
+    del CONFIG.audio_config.symmetric_norm
+    del CONFIG.audio_config.clip_norm
+    breakpoint()
+    stats['audio_config'] = CONFIG.audio_config.to_dict()
     np.save(output_file_path, stats, allow_pickle=True)
     print(f" > stats saved to {output_file_path}")
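The statistics script now assumes a `ConfigManager` that owns argument parsing and typed sub-configs, but the class itself is not part of this diff. A minimal sketch of the interface the calls above imply; the method and field names follow the usage in the diff, everything else is an assumption about the eventual implementation:

    import argparse
    import json


    class ConfigManager:
        """Hypothetical stand-in for TTS.utils.config_manager.ConfigManager."""

        def __init__(self):
            self.audio_config = None    # typed audio sub-config exposing .to_dict()
            self.dataset_config = None  # dataset definitions passed to load_meta_data()

        def init_argparse(self, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
            # register one CLI flag per config field so values can be overridden
            return parser

        def parse_argparse(self, args: argparse.Namespace) -> None:
            # fold the parsed CLI overrides back into the config fields
            pass

        def load_config(self, config_path: str) -> None:
            # read the config file and populate the typed sub-configs
            with open(config_path, "r", encoding="utf-8") as f:
                self._raw = json.load(f)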

View File

@@ -10,11 +10,9 @@ from random import randrange
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import TacotronLoss
+from TTS.tts.configs.tacotron_config import TacotronConfig
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
@@ -24,8 +22,11 @@ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.arguments import parse_arguments, process_args
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor
-from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
+from TTS.utils.config_manager import ConfigManager
+from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
+                                  init_distributed, reduce_tensor)
+from TTS.utils.generic_utils import (KeepAverage, count_parameters,
+                                     remove_experiment_folder, set_init_dict)
 from TTS.utils.radam import RAdam
 from TTS.utils.training import (
     NoamLR,
@@ -739,7 +740,10 @@ def main(args):  # pylint: disable=redefined-outer-name
 if __name__ == "__main__":
     args = parse_arguments(sys.argv)
-    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")
+    c = TacotronConfig()
+    args = c.init_argparse(args)
+    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
+        args, c, model_type='tacotron')

     try:
         main(args)

View File

@@ -2,10 +2,7 @@ import importlib
 import re
 from collections import Counter

-import numpy as np
-import torch
-
-from TTS.utils.generic_utils import check_argument
+from TTS.utils.generic_utils import find_module


 def split_dataset(items):
@@ -39,17 +36,9 @@ def sequence_mask(sequence_length, max_len=None):
     return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)


-def to_camel(text):
-    text = text.capitalize()
-    text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
-    text = text.replace("Tts", "TTS")
-    return text
-
-
 def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
     print(" > Using model: {}".format(c.model))
-    MyModel = importlib.import_module("TTS.tts.models." + c.model.lower())
-    MyModel = getattr(MyModel, to_camel(c.model))
+    find_module("TTS.tts.models", c.model.lower())
     if c.model.lower() in "tacotron":
         model = MyModel(
             num_chars=num_chars + getattr(c, "add_blank", False),
@@ -164,189 +153,156 @@ def is_tacotron(c):
     return "tacotron" in c["model"].lower()


-def check_config_tts(c):
-    check_argument(
-        "model",
-        c,
-        enum_list=["tacotron", "tacotron2", "glow_tts", "speedy_speech", "align_tts"],
-        restricted=True,
-        val_type=str,
-    )
-    check_argument("run_name", c, restricted=True, val_type=str)
-    check_argument("run_description", c, val_type=str)
-    # AUDIO
-    # check_argument('audio', c, restricted=True, val_type=dict)
-    # audio processing parameters
-    # check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
-    # check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
-    # check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
-    # check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
-    # check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
-    # check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
-    # check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
-    # check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
-    # check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
-    # check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
-    # vocabulary parameters
-    check_argument("characters", c, restricted=False, val_type=dict)
-    check_argument(
-        "pad", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
-    )
-    check_argument(
-        "eos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
-    )
-    check_argument(
-        "bos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
-    )
-    check_argument(
-        "characters",
-        c["characters"] if "characters" in c.keys() else {},
-        restricted="characters" in c.keys(),
-        val_type=str,
-    )
-    check_argument(
-        "phonemes",
-        c["characters"] if "characters" in c.keys() else {},
-        restricted="characters" in c.keys() and c["use_phonemes"],
-        val_type=str,
-    )
-    check_argument(
-        "punctuations",
-        c["characters"] if "characters" in c.keys() else {},
-        restricted="characters" in c.keys(),
-        val_type=str,
-    )
-    # normalization parameters
-    # check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
-    # check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
-    # check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
-    # check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
-    # check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
-    # check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
-    # check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100)
-    # check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
-    # check_argument('trim_db', c['audio'], restricted=True, val_type=int)
-    # training parameters
-    # check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
-    # check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
-    # check_argument('r', c, restricted=True, val_type=int, min_val=1)
-    # check_argument('gradual_training', c, restricted=False, val_type=list)
-    # check_argument('mixed_precision', c, restricted=False, val_type=bool)
-    # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
-    # loss parameters
-    # check_argument('loss_masking', c, restricted=True, val_type=bool)
-    # if c['model'].lower() in ['tacotron', 'tacotron2']:
-    #     check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
-    #     check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
-    #     check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
-    #     check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
-    #     check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
-    #     check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
-    #     check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
-    if c['model'].lower in ["speedy_speech", "align_tts"]:
-        check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
-        check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
-        check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
-    # validation parameters
-    # check_argument('run_eval', c, restricted=True, val_type=bool)
-    # check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
-    # check_argument('test_sentences_file', c, restricted=False, val_type=str)
-    # optimizer
-    check_argument("noam_schedule", c, restricted=False, val_type=bool)
-    check_argument("grad_clip", c, restricted=True, val_type=float, min_val=0.0)
-    check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
-    check_argument("lr", c, restricted=True, val_type=float, min_val=0)
-    check_argument("wd", c, restricted=is_tacotron(c), val_type=float, min_val=0)
-    check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
-    check_argument("seq_len_norm", c, restricted=is_tacotron(c), val_type=bool)
-    # tacotron prenet
-    # check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
-    # check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
-    # check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
-    # attention
-    check_argument(
-        "attention_type",
-        c,
-        restricted=is_tacotron(c),
-        val_type=str,
-        enum_list=["graves", "original", "dynamic_convolution"],
-    )
-    check_argument("attention_heads", c, restricted=is_tacotron(c), val_type=int)
-    check_argument("attention_norm", c, restricted=is_tacotron(c), val_type=str, enum_list=["sigmoid", "softmax"])
-    check_argument("windowing", c, restricted=is_tacotron(c), val_type=bool)
-    check_argument("use_forward_attn", c, restricted=is_tacotron(c), val_type=bool)
-    check_argument("forward_attn_mask", c, restricted=is_tacotron(c), val_type=bool)
-    check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
-    check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
-    check_argument("location_attn", c, restricted=is_tacotron(c), val_type=bool)
-    check_argument("bidirectional_decoder", c, restricted=is_tacotron(c), val_type=bool)
-    check_argument("double_decoder_consistency", c, restricted=is_tacotron(c), val_type=bool)
-    check_argument("ddc_r", c, restricted="double_decoder_consistency" in c.keys(), min_val=1, max_val=7, val_type=int)
-    if c["model"].lower() in ["tacotron", "tacotron2"]:
-        # stopnet
-        # check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
-        # check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
-    # Model Parameters for non-tacotron models
-    if c["model"].lower in ["speedy_speech", "align_tts"]:
-        check_argument("positional_encoding", c, restricted=True, val_type=type)
-        check_argument("encoder_type", c, restricted=True, val_type=str)
-        check_argument("encoder_params", c, restricted=True, val_type=dict)
-        check_argument("decoder_residual_conv_bn_params", c, restricted=True, val_type=dict)
-    # GlowTTS parameters
-    check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str)
-    # tensorboard
-    # check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
-    # check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
-    # check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
-    # check_argument('checkpoint', c, restricted=True, val_type=bool)
-    # check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
-    # dataloading
-    # pylint: disable=import-outside-toplevel
-    from TTS.tts.utils.text import cleaners
-    # check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
-    # check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
-    # check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
-    # check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
-    # check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
-    # check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
-    # check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
-    # check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)
-    # paths
-    # check_argument('output_path', c, restricted=True, val_type=str)
-    # multi-speaker and gst
-    # check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
-    # check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool)
-    # check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str)
-    if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
-        # check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
-        # check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
-        # check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
-        # check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
-        # check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
-        # check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
-        # check_argument('gst_num_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
-    # datasets - checking only the first entry
-    # check_argument('datasets', c, restricted=True, val_type=list)
-    # for dataset_entry in c['datasets']:
-    #     check_argument('name', dataset_entry, restricted=True, val_type=str)
-    #     check_argument('path', dataset_entry, restricted=True, val_type=str)
-    #     check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
-    #     check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
+# def check_config_tts(c):
+#     check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str)
+#     check_argument('run_name', c, restricted=True, val_type=str)
+#     check_argument('run_description', c, val_type=str)
+#     # AUDIO
+#     # check_argument('audio', c, restricted=True, val_type=dict)
+#     # audio processing parameters
+#     # check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
+#     # check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
+#     # check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
+#     # check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
+#     # check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
+#     # check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
+#     # check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
+#     # check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
+#     # check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
+#     # check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
+#     # vocabulary parameters
+#     check_argument('characters', c, restricted=False, val_type=dict)
+#     check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+#     check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+#     check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+#     check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+#     check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys() and c['use_phonemes'], val_type=str)
+#     check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+#     # normalization parameters
+#     # check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
+#     # check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
+#     # check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
+#     # check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
+#     # check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
+#     # check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
+#     # check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100)
+#     # check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
+#     # check_argument('trim_db', c['audio'], restricted=True, val_type=int)
+#     # training parameters
+#     # check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
+#     # check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
+#     # check_argument('r', c, restricted=True, val_type=int, min_val=1)
+#     # check_argument('gradual_training', c, restricted=False, val_type=list)
+#     # check_argument('mixed_precision', c, restricted=False, val_type=bool)
+#     # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
+#     # loss parameters
+#     # check_argument('loss_masking', c, restricted=True, val_type=bool)
+#     # if c['model'].lower() in ['tacotron', 'tacotron2']:
+#     #     check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
+#     #     check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
+#     #     check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
+#     #     check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
+#     #     check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
+#     #     check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
+#     #     check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
+#     if c['model'].lower in ["speedy_speech", "align_tts"]:
+#         check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
+#         check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
+#         check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
+#     # validation parameters
+#     # check_argument('run_eval', c, restricted=True, val_type=bool)
+#     # check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
+#     # check_argument('test_sentences_file', c, restricted=False, val_type=str)
+#     # optimizer
+#     check_argument('noam_schedule', c, restricted=False, val_type=bool)
+#     check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
+#     check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
+#     check_argument('lr', c, restricted=True, val_type=float, min_val=0)
+#     check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0)
+#     check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
+#     check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool)
+#     # tacotron prenet
+#     # check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
+#     # check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
+#     # check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
+#     # attention
+#     check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution'])
+#     check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
+#     check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
+#     check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
+#     check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
+#     check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
+#     check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
+#     check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
+#     check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
+#     check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
+#     check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
+#     check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
+#     if c['model'].lower() in ['tacotron', 'tacotron2']:
+#         # stopnet
+#         # check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
+#         # check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
+#     # Model Parameters for non-tacotron models
+#     if c['model'].lower in ["speedy_speech", "align_tts"]:
+#         check_argument('positional_encoding', c, restricted=True, val_type=type)
+#         check_argument('encoder_type', c, restricted=True, val_type=str)
+#         check_argument('encoder_params', c, restricted=True, val_type=dict)
+#         check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict)
+#     # GlowTTS parameters
+#     check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)
+#     # tensorboard
+#     # check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
+#     # check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
+#     # check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
+#     # check_argument('checkpoint', c, restricted=True, val_type=bool)
+#     # check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
+#     # dataloading
+#     # pylint: disable=import-outside-toplevel
+#     from TTS.tts.utils.text import cleaners
+#     # check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
+#     # check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
+#     # check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
+#     # check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
+#     # check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
+#     # check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
+#     # check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
+#     # check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)
+#     # paths
+#     # check_argument('output_path', c, restricted=True, val_type=str)
+#     # multi-speaker and gst
+#     # check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
+#     # check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool)
+#     # check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str)
+#     if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
+#         # check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
+#         # check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
+#         # check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
+#         # check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
+#         # check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
+#         # check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
+#         # check_argument('gst_num_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
+#     # datasets - checking only the first entry
+#     # check_argument('datasets', c, restricted=True, val_type=list)
+#     # for dataset_entry in c['datasets']:
+#     #     check_argument('name', dataset_entry, restricted=True, val_type=str)
+#     #     check_argument('path', dataset_entry, restricted=True, val_type=str)
+#     #     check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
+#     #     check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)

View File

@@ -8,12 +8,10 @@ import json
 import os
 import re

-import torch
 from TTS.tts.utils.text.symbols import parse_symbols
 from TTS.utils.console_logger import ConsoleLogger
 from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
-from TTS.utils.io import copy_model_files, load_config
+from TTS.utils.io import copy_model_files
 from TTS.utils.tensorboard_logger import TensorboardLogger
@@ -140,11 +138,11 @@ def process_args(args, config, tb_prefix):
     if not args.best_path:
         args.best_path = best_model

     # setup output paths and read configs
-    c = config.load_json(args.config_path)
-    if c.mixed_precision:
+    config.load_json(args.config_path)
+    if config.mixed_precision:
         print(" > Mixed precision mode is ON")
-    if not os.path.exists(c.output_path):
-        out_path = create_experiment_folder(c.output_path, c.run_name,
+    if not os.path.exists(config.output_path):
+        out_path = create_experiment_folder(config.output_path, config.run_name,
                                             args.debug)
     audio_path = os.path.join(out_path, "test_audios")
     # setup rank 0 process in distributed training
@@ -157,7 +155,7 @@ def process_args(args, config, tb_prefix):
         # if model characters are not set in the config file
         # save the default set to the config file for future
         # compatibility.
-        if c.has('characters_config'):
+        if config.has('characters_config'):
             used_characters = parse_symbols()
             new_fields["characters"] = used_characters
         copy_model_files(c, args.config_path, out_path, new_fields)
@@ -166,6 +164,6 @@ def process_args(args, config, tb_prefix):
         log_path = out_path
         tb_logger = TensorboardLogger(log_path, model_name=tb_prefix)
         # write model desc to tensorboard
-        tb_logger.tb_add_text("model-description", c["run_description"], 0)
+        tb_logger.tb_add_text("model-description", config["run_description"], 0)
     c_logger = ConsoleLogger()
     return c, out_path, audio_path, c_logger, tb_logger

View File

@@ -1,6 +1,10 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
 import datetime
 import glob
+import importlib
 import os
+import re
 import shutil
 import subprocess
 import sys
@@ -67,6 +71,20 @@ def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)


+def to_camel(text):
+    text = text.capitalize()
+    text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
+    text = text.replace('Tts', 'TTS')
+    return text
+
+
+def find_module(module_path: str, module_name: str) -> object:
+    module_name = module_name.lower()
+    module = importlib.import_module(module_path+'.'+module_name)
+    class_name = to_camel(module_name)
+    return getattr(module, class_name)
+
+
 def get_user_data_dir(appname):
     if sys.platform == "win32":
         import winreg  # pylint: disable=import-outside-toplevel
@@ -139,32 +157,3 @@ class KeepAverage:
         for key, value in value_dict.items():
             self.update_value(key, value)
-
-
-def check_argument(name,
-                   c,
-                   prerequest=None,
-                   enum_list=None,
-                   max_val=None,
-                   min_val=None,
-                   restricted=False,
-                   alternative=None,
-                   allow_none=False):
-    if isinstance(prerequest, List()):
-        if any([f not in c.keys() for f in prerequest]):
-            return
-    else:
-        if prerequest not in c.keys():
-            return
-    if alternative in c.keys() and c[alternative] is not None:
-        return
-    if allow_none and c[name] is None:
-        return
-    if restricted:
-        assert name in c.keys(), f" [!] {name} not defined in config.json"
-    if name in c.keys():
-        if max_val:
-            assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}"
-        if min_val:
-            assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}"
-        if enum_list:
-            assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
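`to_camel` and `find_module` move into `TTS.utils.generic_utils` so that model classes, and config classes in the `load_config` rewrite below, can be resolved from snake_case module names. A quick sketch of the convention they implement, reusing model names that appear elsewhere in this diff (illustrative usage, not part of the commit):

    from TTS.utils.generic_utils import find_module, to_camel

    # snake_case module name -> CamelCase class name, with 'Tts' normalized to 'TTS'
    print(to_camel("tacotron2"))  # -> Tacotron2
    print(to_camel("glow_tts"))   # -> GlowTTS

    # imports TTS.tts.models.glow_tts and returns its GlowTTS class
    model_class = find_module("TTS.tts.models", "glow_tts")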

View File

@@ -3,6 +3,7 @@ import os
 import pickle as pickle_tts
 import re
 from shutil import copyfile
+from TTS.utils.generic_utils import find_module

 import yaml
@@ -23,32 +24,37 @@ class AttrDict(dict):
         self.__dict__ = self


-# def read_json_with_comments(json_path):
-#     # fallback to json
-#     with open(json_path, "r", encoding="utf-8") as f:
-#         input_str = f.read()
-#     # handle comments
-#     input_str = re.sub(r'\\\n', '', input_str)
-#     input_str = re.sub(r'//.*\n', '\n', input_str)
-#     data = json.loads(input_str)
-#     return data
+def read_json_with_comments(json_path):
+    """DEPRECATED"""
+    # fallback to json
+    with open(json_path, "r", encoding="utf-8") as f:
+        input_str = f.read()
+    # handle comments
+    input_str = re.sub(r'\\\n', '', input_str)
+    input_str = re.sub(r'//.*\n', '\n', input_str)
+    data = json.loads(input_str)
+    return data


-# def load_config(config_path: str) -> AttrDict:
-#     """Load config files and discard comments
-#     Args:
-#         config_path (str): path to config file.
-#     """
-#     config = AttrDict()
-#     ext = os.path.splitext(config_path)[1]
-#     # if ext in (".yml", ".yaml"):
-#     #     with open(config_path, "r", encoding="utf-8") as f:
-#     #         data = yaml.safe_load(f)
-#     # else:
-#     data = read_json_with_comments(config_path)
-#     config.update(data)
-#     return config
+def load_config(config_path: str) -> AttrDict:
+    """DEPRECATED: Load config files and discard comments
+    Args:
+        config_path (str): path to config file.
+    """
+    config_dict = AttrDict()
+    ext = os.path.splitext(config_path)[1]
+    if ext in (".yml", ".yaml"):
+        with open(config_path, "r", encoding="utf-8") as f:
+            data = yaml.safe_load(f)
+    else:
+        with open(config_path, "r", encoding="utf-8") as f:
+            input_str = f.read()
+        data = json.loads(input_str)
+    config_dict.update(data)
+    config_class = find_module('TTS.tts.configs', config_dict.model.lower()+'_config')
+    config = config_class()
+    config.from_dict(config_dict)
+    return


 def copy_model_files(c, config_file, out_path, new_fields):
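The rewritten `load_config` resolves a typed config class from the JSON's `model` field via the `<model>_config` module convention. Roughly what that lookup does for a config with `"model": "glow_tts"`, assuming a `TTS.tts.configs.glow_tts_config` module that defines `GlowTTSConfig`; note that the function as committed ends with a bare `return`, so the resolved config is not yet handed back to callers:

    from TTS.utils.generic_utils import find_module

    config_class = find_module("TTS.tts.configs", "glow_tts" + "_config")  # -> GlowTTSConfig
    config = config_class()
    config.from_dict({"model": "glow_tts"})  # populate typed fields from the parsed JSON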