config refactor #4 WIP

This commit is contained in:
Eren Gölge 2021-04-01 18:22:24 +02:00
parent 97bd5f9734
commit dc50f5f0b0
6 changed files with 229 additions and 267 deletions

View File

@ -8,6 +8,7 @@ import os
import numpy as np
from tqdm import tqdm
from TTS.utils.config_manager import ConfigManager
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
@ -15,26 +16,33 @@ from TTS.utils.io import load_config
def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
parser.add_argument(
"--config_path", type=str, required=True, help="TTS config file path to define audio processin parameters."
)
parser.add_argument("--out_path", type=str, required=True, help="save path (directory and filename).")
CONFIG = ConfigManager()
parser = argparse.ArgumentParser(
description="Compute mean and variance of spectrogtram features.")
parser.add_argument("config_path", type=str,
help="TTS config file path to define audio processin parameters.")
parser.add_argument("out_path", type=str,
help="save path (directory and filename).")
parser.add_argument("--data_path", type=str, required=False,
help="folder including the target set of wavs overriding dataset config.")
parser = CONFIG.init_argparse(parser)
args = parser.parse_args()
CONFIG.parse_argparse(args)
# load config
CONFIG = load_config(args.config_path)
CONFIG.audio["signal_norm"] = False # do not apply earlier normalization
CONFIG.audio["stats_path"] = None # discard pre-defined stats
CONFIG.load_config(args.config_path)
CONFIG.audio_config.signal_norm = False # do not apply earlier normalization
CONFIG.audio_config.stats_path = None # discard pre-defined stats
# load audio processor
ap = AudioProcessor(**CONFIG.audio)
ap = AudioProcessor(**CONFIG.audio_config.to_dict())
# load the meta data of target dataset
if "data_path" in CONFIG.keys():
dataset_items = glob.glob(os.path.join(CONFIG.data_path, "**", "*.wav"), recursive=True)
if args.data_path:
dataset_items = glob.glob(os.path.join(args.data_path, '**', '*.wav'), recursive=True)
else:
dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data
dataset_items = load_meta_data(CONFIG.dataset_config)[0] # take only train data
print(f" > There are {len(dataset_items)} files.")
mel_sum = 0
@ -73,14 +81,15 @@ def main():
print(f" > Avg lienar spec scale: {linear_scale.mean()}")
# set default config values for mean-var scaling
CONFIG.audio["stats_path"] = output_file_path
CONFIG.audio["signal_norm"] = True
CONFIG.audio_config.stats_path = output_file_path
CONFIG.audio_config.signal_norm = True
# remove redundant values
del CONFIG.audio["max_norm"]
del CONFIG.audio["min_level_db"]
del CONFIG.audio["symmetric_norm"]
del CONFIG.audio["clip_norm"]
stats["audio_config"] = CONFIG.audio
del CONFIG.audio_config.max_norm
del CONFIG.audio_config.min_level_db
del CONFIG.audio_config.symmetric_norm
del CONFIG.audio_config.clip_norm
breakpoint()
stats['audio_config'] = CONFIG.audio_config.to_dict()
np.save(output_file_path, stats, allow_pickle=True)
print(f" > stats saved to {output_file_path}")

View File

@ -10,11 +10,9 @@ from random import randrange
import numpy as np
import torch
from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
@ -24,8 +22,11 @@ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.arguments import parse_arguments, process_args
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.config_manager import ConfigManager
from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
init_distributed, reduce_tensor)
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.radam import RAdam
from TTS.utils.training import (
NoamLR,
@ -739,7 +740,10 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == "__main__":
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")
c = TacotronConfig()
args = c.init_argparse(args)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, c, model_type='tacotron')
try:
main(args)

View File

@ -2,10 +2,7 @@ import importlib
import re
from collections import Counter
import numpy as np
import torch
from TTS.utils.generic_utils import check_argument
from TTS.utils.generic_utils import find_module
def split_dataset(items):
@ -39,17 +36,9 @@ def sequence_mask(sequence_length, max_len=None):
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
def to_camel(text):
    """Convert a snake_case model name to CamelCase, upper-casing the "TTS" acronym."""
    capitalized = text.capitalize()
    camel = re.sub(r"(?!^)_([a-zA-Z])", lambda match: match.group(1).upper(), capitalized)
    return camel.replace("Tts", "TTS")
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module("TTS.tts.models." + c.model.lower())
MyModel = getattr(MyModel, to_camel(c.model))
find_module("TTS.tts.models", c.model.lower())
if c.model.lower() in "tacotron":
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
@ -164,189 +153,156 @@ def is_tacotron(c):
return "tacotron" in c["model"].lower()
def check_config_tts(c):
check_argument(
"model",
c,
enum_list=["tacotron", "tacotron2", "glow_tts", "speedy_speech", "align_tts"],
restricted=True,
val_type=str,
)
check_argument("run_name", c, restricted=True, val_type=str)
check_argument("run_description", c, val_type=str)
# def check_config_tts(c):
# check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str)
# check_argument('run_name', c, restricted=True, val_type=str)
# check_argument('run_description', c, val_type=str)
# AUDIO
# check_argument('audio', c, restricted=True, val_type=dict)
# # AUDIO
# # check_argument('audio', c, restricted=True, val_type=dict)
# audio processing parameters
# check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
# check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
# check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
# check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
# check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
# check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
# check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
# check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
# check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
# check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
# # audio processing parameters
# # check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
# # check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
# # check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
# # check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
# # check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
# # check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
# # check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
# # check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
# # check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
# # check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
# vocabulary parameters
check_argument("characters", c, restricted=False, val_type=dict)
check_argument(
"pad", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
)
check_argument(
"eos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
)
check_argument(
"bos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
)
check_argument(
"characters",
c["characters"] if "characters" in c.keys() else {},
restricted="characters" in c.keys(),
val_type=str,
)
check_argument(
"phonemes",
c["characters"] if "characters" in c.keys() else {},
restricted="characters" in c.keys() and c["use_phonemes"],
val_type=str,
)
check_argument(
"punctuations",
c["characters"] if "characters" in c.keys() else {},
restricted="characters" in c.keys(),
val_type=str,
)
# # vocabulary parameters
# check_argument('characters', c, restricted=False, val_type=dict)
# check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
# check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
# check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
# check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
# check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys() and c['use_phonemes'], val_type=str)
# check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
# normalization parameters
# check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
# check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
# check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
# check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
# check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
# check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
# check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100)
# check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
# check_argument('trim_db', c['audio'], restricted=True, val_type=int)
# # normalization parameters
# # check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
# # check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
# # check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
# # check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
# # check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
# # check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
# # check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100)
# # check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
# # check_argument('trim_db', c['audio'], restricted=True, val_type=int)
# training parameters
# check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
# check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
# check_argument('r', c, restricted=True, val_type=int, min_val=1)
# check_argument('gradual_training', c, restricted=False, val_type=list)
# check_argument('mixed_precision', c, restricted=False, val_type=bool)
# check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
# # training parameters
# # check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
# # check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
# # check_argument('r', c, restricted=True, val_type=int, min_val=1)
# # check_argument('gradual_training', c, restricted=False, val_type=list)
# # check_argument('mixed_precision', c, restricted=False, val_type=bool)
# # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
# loss parameters
# check_argument('loss_masking', c, restricted=True, val_type=bool)
# if c['model'].lower() in ['tacotron', 'tacotron2']:
# check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
# check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
# check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
# check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
# check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
# check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
# check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
if c['model'].lower in ["speedy_speech", "align_tts"]:
check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
# # loss parameters
# # check_argument('loss_masking', c, restricted=True, val_type=bool)
# # if c['model'].lower() in ['tacotron', 'tacotron2']:
# # check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
# if c['model'].lower in ["speedy_speech", "align_tts"]:
# check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
# check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
# check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
# validation parameters
# check_argument('run_eval', c, restricted=True, val_type=bool)
# check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
# check_argument('test_sentences_file', c, restricted=False, val_type=str)
# # validation parameters
# # check_argument('run_eval', c, restricted=True, val_type=bool)
# # check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
# # check_argument('test_sentences_file', c, restricted=False, val_type=str)
# optimizer
check_argument("noam_schedule", c, restricted=False, val_type=bool)
check_argument("grad_clip", c, restricted=True, val_type=float, min_val=0.0)
check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
check_argument("lr", c, restricted=True, val_type=float, min_val=0)
check_argument("wd", c, restricted=is_tacotron(c), val_type=float, min_val=0)
check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
check_argument("seq_len_norm", c, restricted=is_tacotron(c), val_type=bool)
# # optimizer
# check_argument('noam_schedule', c, restricted=False, val_type=bool)
# check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
# check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
# check_argument('lr', c, restricted=True, val_type=float, min_val=0)
# check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0)
# check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
# check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool)
# tacotron prenet
# check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
# check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
# check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
# # tacotron prenet
# # check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
# # check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
# # check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
# attention
check_argument(
"attention_type",
c,
restricted=is_tacotron(c),
val_type=str,
enum_list=["graves", "original", "dynamic_convolution"],
)
check_argument("attention_heads", c, restricted=is_tacotron(c), val_type=int)
check_argument("attention_norm", c, restricted=is_tacotron(c), val_type=str, enum_list=["sigmoid", "softmax"])
check_argument("windowing", c, restricted=is_tacotron(c), val_type=bool)
check_argument("use_forward_attn", c, restricted=is_tacotron(c), val_type=bool)
check_argument("forward_attn_mask", c, restricted=is_tacotron(c), val_type=bool)
check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
check_argument("location_attn", c, restricted=is_tacotron(c), val_type=bool)
check_argument("bidirectional_decoder", c, restricted=is_tacotron(c), val_type=bool)
check_argument("double_decoder_consistency", c, restricted=is_tacotron(c), val_type=bool)
check_argument("ddc_r", c, restricted="double_decoder_consistency" in c.keys(), min_val=1, max_val=7, val_type=int)
# # attention
# check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution'])
# check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
# check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
# check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
if c["model"].lower() in ["tacotron", "tacotron2"]:
# stopnet
# check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
# if c['model'].lower() in ['tacotron', 'tacotron2']:
# # stopnet
# # check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
# # check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
# Model Parameters for non-tacotron models
if c["model"].lower in ["speedy_speech", "align_tts"]:
check_argument("positional_encoding", c, restricted=True, val_type=type)
check_argument("encoder_type", c, restricted=True, val_type=str)
check_argument("encoder_params", c, restricted=True, val_type=dict)
check_argument("decoder_residual_conv_bn_params", c, restricted=True, val_type=dict)
# # Model Parameters for non-tacotron models
# if c['model'].lower in ["speedy_speech", "align_tts"]:
# check_argument('positional_encoding', c, restricted=True, val_type=type)
# check_argument('encoder_type', c, restricted=True, val_type=str)
# check_argument('encoder_params', c, restricted=True, val_type=dict)
# check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict)
# GlowTTS parameters
check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str)
# # GlowTTS parameters
# check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)
# tensorboard
# check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
# check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
# check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
# check_argument('checkpoint', c, restricted=True, val_type=bool)
# check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
# # tensorboard
# # check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
# # check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
# # check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
# # check_argument('checkpoint', c, restricted=True, val_type=bool)
# # check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
# dataloading
# pylint: disable=import-outside-toplevel
from TTS.tts.utils.text import cleaners
# check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
# check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
# check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
# check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
# check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
# check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
# check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
# check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)
# # dataloading
# # pylint: disable=import-outside-toplevel
# from TTS.tts.utils.text import cleaners
# # check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
# # check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
# # check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
# # check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
# # check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
# # check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
# # check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
# # check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)
# paths
# check_argument('output_path', c, restricted=True, val_type=str)
# # paths
# # check_argument('output_path', c, restricted=True, val_type=str)
# multi-speaker and gst
# check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
# check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool)
# check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str)
if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
# check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
# check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
# check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
# check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
# check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
# check_argument('gst_num_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
# # multi-speaker and gst
# # check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
# # check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool)
# # check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str)
# if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
# # check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
# # check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
# # check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
# # check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
# # check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
# # check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
# # check_argument('gst_num_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
# datasets - checking only the first entry
# check_argument('datasets', c, restricted=True, val_type=list)
# for dataset_entry in c['datasets']:
# check_argument('name', dataset_entry, restricted=True, val_type=str)
# check_argument('path', dataset_entry, restricted=True, val_type=str)
# check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
# check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
# # datasets - checking only the first entry
# # check_argument('datasets', c, restricted=True, val_type=list)
# # for dataset_entry in c['datasets']:
# # check_argument('name', dataset_entry, restricted=True, val_type=str)
# # check_argument('path', dataset_entry, restricted=True, val_type=str)
# # check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
# # check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)

View File

@ -8,12 +8,10 @@ import json
import os
import re
import torch
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.io import copy_model_files
from TTS.utils.tensorboard_logger import TensorboardLogger
@ -140,11 +138,11 @@ def process_args(args, config, tb_prefix):
if not args.best_path:
args.best_path = best_model
# setup output paths and read configs
c = config.load_json(args.config_path)
if c.mixed_precision:
config.load_json(args.config_path)
if config.mixed_precision:
print(" > Mixed precision mode is ON")
if not os.path.exists(c.output_path):
out_path = create_experiment_folder(c.output_path, c.run_name,
if not os.path.exists(config.output_path):
out_path = create_experiment_folder(config.output_path, config.run_name,
args.debug)
audio_path = os.path.join(out_path, "test_audios")
# setup rank 0 process in distributed training
@ -157,7 +155,7 @@ def process_args(args, config, tb_prefix):
# if model characters are not set in the config file
# save the default set to the config file for future
# compatibility.
if c.has('characters_config'):
if config.has('characters_config'):
used_characters = parse_symbols()
new_fields["characters"] = used_characters
copy_model_files(c, args.config_path, out_path, new_fields)
@ -166,6 +164,6 @@ def process_args(args, config, tb_prefix):
log_path = out_path
tb_logger = TensorboardLogger(log_path, model_name=tb_prefix)
# write model desc to tensorboard
tb_logger.tb_add_text("model-description", c["run_description"], 0)
tb_logger.tb_add_text("model-description", config["run_description"], 0)
c_logger = ConsoleLogger()
return c, out_path, audio_path, c_logger, tb_logger

View File

@ -1,6 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import datetime
import glob
import importlib
import os
import re
import shutil
import subprocess
import sys
@ -67,6 +71,20 @@ def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def to_camel(text):
    """Turn a snake_case name into CamelCase ("glow_tts" -> "GlowTTS")."""
    # capitalize() lower-cases the tail, then each "_x" becomes "X"
    camel = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text.capitalize())
    return camel.replace('Tts', 'TTS')
def find_module(module_path: str, module_name: str) -> object:
    """Import ``module_path.module_name`` and return the class named after it.

    The class name is derived from the module name via ``to_camel``, e.g.
    ``find_module("TTS.tts.models", "glow_tts")`` returns the ``GlowTTS`` class.
    """
    lowered = module_name.lower()
    module = importlib.import_module(module_path + '.' + lowered)
    return getattr(module, to_camel(lowered))
def get_user_data_dir(appname):
if sys.platform == "win32":
import winreg # pylint: disable=import-outside-toplevel
@ -139,32 +157,3 @@ class KeepAverage:
for key, value in value_dict.items():
self.update_value(key, value)
def check_argument(name,
                   c,
                   prerequest=None,
                   enum_list=None,
                   max_val=None,
                   min_val=None,
                   restricted=False,
                   alternative=None,
                   allow_none=False,
                   val_type=None):
    """Validate a single field of a config dict, raising ``AssertionError`` on violation.

    Args:
        name (str): name of the field to check.
        c (dict): config dict holding the field.
        prerequest (str or list): field(s) that must be present in ``c`` for this
            check to apply; the check is skipped when any is missing.
        enum_list (list): allowed (lower-case) values for the field.
        max_val: inclusive upper bound for numeric fields.
        min_val: inclusive lower bound for numeric fields.
        restricted (bool): if True, the field must be defined in ``c``.
        alternative (str): name of an alternative field that, when set, satisfies
            this check.
        allow_none (bool): if True, ``None`` is accepted as a valid value.
        val_type (type or list of types): expected type(s) of the value. (Added to
            match call sites which pass ``val_type=...``; the old signature rejected it.)
    """
    # skip the check entirely when a prerequisite field is missing.
    # BUG FIX: the old code did ``isinstance(prerequest, List())`` (NameError) and,
    # for ``prerequest=None``, evaluated ``None not in c.keys()`` -> always returned
    # early, silently disabling every check.
    if prerequest is not None:
        if isinstance(prerequest, list):
            if any(f not in c.keys() for f in prerequest):
                return
        elif prerequest not in c.keys():
            return
    # an alternative field with a concrete value satisfies the requirement
    if alternative is not None and alternative in c.keys() and c[alternative] is not None:
        return
    if allow_none and name in c.keys() and c[name] is None:
        return
    if restricted:
        assert name in c.keys(), f" [!] {name} not defined in config.json"
    if name in c.keys():
        # BUG FIX: use ``is not None`` so legitimate 0 bounds (callers pass
        # ``min_val=0``) are not skipped by truthiness.
        if max_val is not None:
            assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}"
        if min_val is not None:
            assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}"
        if enum_list is not None:
            assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
        if val_type is not None:
            types = val_type if isinstance(val_type, list) else [val_type]
            assert any(isinstance(c[name], t) for t in types), \
                f" [!] {name} has wrong type - {type(c[name])} vs {types}"

View File

@ -3,6 +3,7 @@ import os
import pickle as pickle_tts
import re
from shutil import copyfile
from TTS.utils.generic_utils import find_module
import yaml
@ -23,32 +24,37 @@ class AttrDict(dict):
self.__dict__ = self
# def read_json_with_comments(json_path):
# # fallback to json
# with open(json_path, "r", encoding="utf-8") as f:
# input_str = f.read()
# # handle comments
# input_str = re.sub(r'\\\n', '', input_str)
# input_str = re.sub(r'//.*\n', '\n', input_str)
# data = json.loads(input_str)
# return data
def read_json_with_comments(json_path):
    """DEPRECATED"""
    # fallback to json
    with open(json_path, "r", encoding="utf-8") as json_file:
        raw = json_file.read()
    # drop line continuations, then strip ``//`` comments
    # NOTE(review): the pattern also matches ``//`` inside string values
    # (e.g. URLs) — preserved here for backward compatibility; confirm configs
    # never embed ``//`` in strings.
    raw = re.sub(r'\\\n', '', raw)
    raw = re.sub(r'//.*\n', '\n', raw)
    return json.loads(raw)
# def load_config(config_path: str) -> AttrDict:
# """Load config files and discard comments
def load_config(config_path: str) -> AttrDict:
    """DEPRECATED: Load a config file and return the matching model config object.

    The ``model`` field of the file selects the config class: e.g. ``"glow_tts"``
    resolves to the class exposed by ``TTS.tts.configs.glow_tts_config``.

    Args:
        config_path (str): path to the config file (``.json``, ``.yml`` or ``.yaml``).

    Returns:
        The model-specific config instance populated from the file.
        NOTE(review): the ``-> AttrDict`` annotation is kept for interface
        stability but the actual return is the resolved config class instance.
    """
    config_dict = AttrDict()
    ext = os.path.splitext(config_path)[1]
    if ext in (".yml", ".yaml"):
        with open(config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    else:
        # plain json; comments are no longer supported here
        with open(config_path, "r", encoding="utf-8") as f:
            data = json.loads(f.read())
    config_dict.update(data)
    config_class = find_module('TTS.tts.configs', config_dict.model.lower() + '_config')
    config = config_class()
    config.from_dict(config_dict)
    # BUG FIX: the previous version ended with a bare ``return`` and gave
    # callers ``None`` instead of the constructed config.
    return config
def copy_model_files(c, config_file, out_path, new_fields):