mirror of https://github.com/coqui-ai/TTS.git
config refactor #4 WIP
This commit is contained in:
parent
97bd5f9734
commit
dc50f5f0b0
|
@ -8,6 +8,7 @@ import os
|
|||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.utils.config_manager import ConfigManager
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_config
|
||||
|
@ -15,26 +16,33 @@ from TTS.utils.io import load_config
|
|||
|
||||
def main():
|
||||
"""Run preprocessing process."""
|
||||
parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
|
||||
parser.add_argument(
|
||||
"--config_path", type=str, required=True, help="TTS config file path to define audio processin parameters."
|
||||
)
|
||||
parser.add_argument("--out_path", type=str, required=True, help="save path (directory and filename).")
|
||||
CONFIG = ConfigManager()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute mean and variance of spectrogtram features.")
|
||||
parser.add_argument("config_path", type=str,
|
||||
help="TTS config file path to define audio processin parameters.")
|
||||
parser.add_argument("out_path", type=str,
|
||||
help="save path (directory and filename).")
|
||||
parser.add_argument("--data_path", type=str, required=False,
|
||||
help="folder including the target set of wavs overriding dataset config.")
|
||||
parser = CONFIG.init_argparse(parser)
|
||||
args = parser.parse_args()
|
||||
CONFIG.parse_argparse(args)
|
||||
|
||||
# load config
|
||||
CONFIG = load_config(args.config_path)
|
||||
CONFIG.audio["signal_norm"] = False # do not apply earlier normalization
|
||||
CONFIG.audio["stats_path"] = None # discard pre-defined stats
|
||||
CONFIG.load_config(args.config_path)
|
||||
CONFIG.audio_config.signal_norm = False # do not apply earlier normalization
|
||||
CONFIG.audio_config.stats_path = None # discard pre-defined stats
|
||||
|
||||
# load audio processor
|
||||
ap = AudioProcessor(**CONFIG.audio)
|
||||
ap = AudioProcessor(**CONFIG.audio_config.to_dict())
|
||||
|
||||
# load the meta data of target dataset
|
||||
if "data_path" in CONFIG.keys():
|
||||
dataset_items = glob.glob(os.path.join(CONFIG.data_path, "**", "*.wav"), recursive=True)
|
||||
if args.data_path:
|
||||
dataset_items = glob.glob(os.path.join(args.data_path, '**', '*.wav'), recursive=True)
|
||||
else:
|
||||
dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data
|
||||
dataset_items = load_meta_data(CONFIG.dataset_config)[0] # take only train data
|
||||
print(f" > There are {len(dataset_items)} files.")
|
||||
|
||||
mel_sum = 0
|
||||
|
@ -73,14 +81,15 @@ def main():
|
|||
print(f" > Avg lienar spec scale: {linear_scale.mean()}")
|
||||
|
||||
# set default config values for mean-var scaling
|
||||
CONFIG.audio["stats_path"] = output_file_path
|
||||
CONFIG.audio["signal_norm"] = True
|
||||
CONFIG.audio_config.stats_path = output_file_path
|
||||
CONFIG.audio_config.signal_norm = True
|
||||
# remove redundant values
|
||||
del CONFIG.audio["max_norm"]
|
||||
del CONFIG.audio["min_level_db"]
|
||||
del CONFIG.audio["symmetric_norm"]
|
||||
del CONFIG.audio["clip_norm"]
|
||||
stats["audio_config"] = CONFIG.audio
|
||||
del CONFIG.audio_config.max_norm
|
||||
del CONFIG.audio_config.min_level_db
|
||||
del CONFIG.audio_config.symmetric_norm
|
||||
del CONFIG.audio_config.clip_norm
|
||||
breakpoint()
|
||||
stats['audio_config'] = CONFIG.audio_config.to_dict()
|
||||
np.save(output_file_path, stats, allow_pickle=True)
|
||||
print(f" > stats saved to {output_file_path}")
|
||||
|
||||
|
|
|
@ -10,11 +10,9 @@ from random import randrange
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
||||
from TTS.tts.layers.losses import TacotronLoss
|
||||
from TTS.tts.configs.tacotron_config import TacotronConfig
|
||||
from TTS.tts.utils.generic_utils import setup_model
|
||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
|
@ -24,8 +22,11 @@ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
|||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.arguments import parse_arguments, process_args
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor
|
||||
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
|
||||
from TTS.utils.config_manager import ConfigManager
|
||||
from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||
init_distributed, reduce_tensor)
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.training import (
|
||||
NoamLR,
|
||||
|
@ -739,7 +740,10 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")
|
||||
c = TacotronConfig()
|
||||
args = c.init_argparse(args)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
|
||||
args, c, model_type='tacotron')
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
@ -2,10 +2,7 @@ import importlib
|
|||
import re
|
||||
from collections import Counter
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from TTS.utils.generic_utils import check_argument
|
||||
from TTS.utils.generic_utils import find_module
|
||||
|
||||
|
||||
def split_dataset(items):
|
||||
|
@ -39,17 +36,9 @@ def sequence_mask(sequence_length, max_len=None):
|
|||
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
|
||||
|
||||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
||||
text = text.replace("Tts", "TTS")
|
||||
return text
|
||||
|
||||
|
||||
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||
print(" > Using model: {}".format(c.model))
|
||||
MyModel = importlib.import_module("TTS.tts.models." + c.model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.model))
|
||||
find_module("TTS.tts.models", c.model.lower())
|
||||
if c.model.lower() in "tacotron":
|
||||
model = MyModel(
|
||||
num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
|
@ -164,189 +153,156 @@ def is_tacotron(c):
|
|||
return "tacotron" in c["model"].lower()
|
||||
|
||||
|
||||
def check_config_tts(c):
|
||||
check_argument(
|
||||
"model",
|
||||
c,
|
||||
enum_list=["tacotron", "tacotron2", "glow_tts", "speedy_speech", "align_tts"],
|
||||
restricted=True,
|
||||
val_type=str,
|
||||
)
|
||||
check_argument("run_name", c, restricted=True, val_type=str)
|
||||
check_argument("run_description", c, val_type=str)
|
||||
# def check_config_tts(c):
|
||||
# check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str)
|
||||
# check_argument('run_name', c, restricted=True, val_type=str)
|
||||
# check_argument('run_description', c, val_type=str)
|
||||
|
||||
# AUDIO
|
||||
# check_argument('audio', c, restricted=True, val_type=dict)
|
||||
# # AUDIO
|
||||
# # check_argument('audio', c, restricted=True, val_type=dict)
|
||||
|
||||
# audio processing parameters
|
||||
# check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
|
||||
# check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
|
||||
# check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
|
||||
# check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
|
||||
# check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
|
||||
# check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
|
||||
# check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
|
||||
# check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
|
||||
# check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
|
||||
# check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
|
||||
# # audio processing parameters
|
||||
# # check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
|
||||
# # check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
|
||||
# # check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
|
||||
# # check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
|
||||
# # check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
|
||||
# # check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
|
||||
# # check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
|
||||
# # check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
|
||||
# # check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
|
||||
# # check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
|
||||
|
||||
# vocabulary parameters
|
||||
check_argument("characters", c, restricted=False, val_type=dict)
|
||||
check_argument(
|
||||
"pad", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
|
||||
)
|
||||
check_argument(
|
||||
"eos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
|
||||
)
|
||||
check_argument(
|
||||
"bos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
|
||||
)
|
||||
check_argument(
|
||||
"characters",
|
||||
c["characters"] if "characters" in c.keys() else {},
|
||||
restricted="characters" in c.keys(),
|
||||
val_type=str,
|
||||
)
|
||||
check_argument(
|
||||
"phonemes",
|
||||
c["characters"] if "characters" in c.keys() else {},
|
||||
restricted="characters" in c.keys() and c["use_phonemes"],
|
||||
val_type=str,
|
||||
)
|
||||
check_argument(
|
||||
"punctuations",
|
||||
c["characters"] if "characters" in c.keys() else {},
|
||||
restricted="characters" in c.keys(),
|
||||
val_type=str,
|
||||
)
|
||||
# # vocabulary parameters
|
||||
# check_argument('characters', c, restricted=False, val_type=dict)
|
||||
# check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
# check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
# check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
# check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
# check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys() and c['use_phonemes'], val_type=str)
|
||||
# check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
|
||||
# normalization parameters
|
||||
# check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
|
||||
# check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
|
||||
# check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
|
||||
# check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
|
||||
# check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
|
||||
# check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
|
||||
# check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100)
|
||||
# check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
|
||||
# check_argument('trim_db', c['audio'], restricted=True, val_type=int)
|
||||
# # normalization parameters
|
||||
# # check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
|
||||
# # check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
|
||||
# # check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
|
||||
# # check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
|
||||
# # check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
|
||||
# # check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
|
||||
# # check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100)
|
||||
# # check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
|
||||
# # check_argument('trim_db', c['audio'], restricted=True, val_type=int)
|
||||
|
||||
# training parameters
|
||||
# check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
|
||||
# check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
|
||||
# check_argument('r', c, restricted=True, val_type=int, min_val=1)
|
||||
# check_argument('gradual_training', c, restricted=False, val_type=list)
|
||||
# check_argument('mixed_precision', c, restricted=False, val_type=bool)
|
||||
# check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
|
||||
# # training parameters
|
||||
# # check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
|
||||
# # check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
|
||||
# # check_argument('r', c, restricted=True, val_type=int, min_val=1)
|
||||
# # check_argument('gradual_training', c, restricted=False, val_type=list)
|
||||
# # check_argument('mixed_precision', c, restricted=False, val_type=bool)
|
||||
# # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
|
||||
|
||||
# loss parameters
|
||||
# check_argument('loss_masking', c, restricted=True, val_type=bool)
|
||||
# if c['model'].lower() in ['tacotron', 'tacotron2']:
|
||||
# check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
if c['model'].lower in ["speedy_speech", "align_tts"]:
|
||||
check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# # loss parameters
|
||||
# # check_argument('loss_masking', c, restricted=True, val_type=bool)
|
||||
# # if c['model'].lower() in ['tacotron', 'tacotron2']:
|
||||
# # check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# # check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# # check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# # check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# # check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# # check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# # check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# if c['model'].lower in ["speedy_speech", "align_tts"]:
|
||||
# check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
# check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
|
||||
# validation parameters
|
||||
# check_argument('run_eval', c, restricted=True, val_type=bool)
|
||||
# check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
|
||||
# check_argument('test_sentences_file', c, restricted=False, val_type=str)
|
||||
# # validation parameters
|
||||
# # check_argument('run_eval', c, restricted=True, val_type=bool)
|
||||
# # check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
|
||||
# # check_argument('test_sentences_file', c, restricted=False, val_type=str)
|
||||
|
||||
# optimizer
|
||||
check_argument("noam_schedule", c, restricted=False, val_type=bool)
|
||||
check_argument("grad_clip", c, restricted=True, val_type=float, min_val=0.0)
|
||||
check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
|
||||
check_argument("lr", c, restricted=True, val_type=float, min_val=0)
|
||||
check_argument("wd", c, restricted=is_tacotron(c), val_type=float, min_val=0)
|
||||
check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
|
||||
check_argument("seq_len_norm", c, restricted=is_tacotron(c), val_type=bool)
|
||||
# # optimizer
|
||||
# check_argument('noam_schedule', c, restricted=False, val_type=bool)
|
||||
# check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
|
||||
# check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
|
||||
# check_argument('lr', c, restricted=True, val_type=float, min_val=0)
|
||||
# check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0)
|
||||
# check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
|
||||
# check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool)
|
||||
|
||||
# tacotron prenet
|
||||
# check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
|
||||
# check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
|
||||
# check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# # tacotron prenet
|
||||
# # check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
|
||||
# # check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
|
||||
# # check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
|
||||
|
||||
# attention
|
||||
check_argument(
|
||||
"attention_type",
|
||||
c,
|
||||
restricted=is_tacotron(c),
|
||||
val_type=str,
|
||||
enum_list=["graves", "original", "dynamic_convolution"],
|
||||
)
|
||||
check_argument("attention_heads", c, restricted=is_tacotron(c), val_type=int)
|
||||
check_argument("attention_norm", c, restricted=is_tacotron(c), val_type=str, enum_list=["sigmoid", "softmax"])
|
||||
check_argument("windowing", c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument("use_forward_attn", c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument("forward_attn_mask", c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument("location_attn", c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument("bidirectional_decoder", c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument("double_decoder_consistency", c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument("ddc_r", c, restricted="double_decoder_consistency" in c.keys(), min_val=1, max_val=7, val_type=int)
|
||||
# # attention
|
||||
# check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution'])
|
||||
# check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
|
||||
# check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
|
||||
# check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
|
||||
|
||||
if c["model"].lower() in ["tacotron", "tacotron2"]:
|
||||
# stopnet
|
||||
# check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# if c['model'].lower() in ['tacotron', 'tacotron2']:
|
||||
# # stopnet
|
||||
# # check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# # check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
||||
|
||||
# Model Parameters for non-tacotron models
|
||||
if c["model"].lower in ["speedy_speech", "align_tts"]:
|
||||
check_argument("positional_encoding", c, restricted=True, val_type=type)
|
||||
check_argument("encoder_type", c, restricted=True, val_type=str)
|
||||
check_argument("encoder_params", c, restricted=True, val_type=dict)
|
||||
check_argument("decoder_residual_conv_bn_params", c, restricted=True, val_type=dict)
|
||||
# # Model Parameters for non-tacotron models
|
||||
# if c['model'].lower in ["speedy_speech", "align_tts"]:
|
||||
# check_argument('positional_encoding', c, restricted=True, val_type=type)
|
||||
# check_argument('encoder_type', c, restricted=True, val_type=str)
|
||||
# check_argument('encoder_params', c, restricted=True, val_type=dict)
|
||||
# check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict)
|
||||
|
||||
# GlowTTS parameters
|
||||
check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str)
|
||||
# # GlowTTS parameters
|
||||
# check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)
|
||||
|
||||
# tensorboard
|
||||
# check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
|
||||
# check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
|
||||
# check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
|
||||
# check_argument('checkpoint', c, restricted=True, val_type=bool)
|
||||
# check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
|
||||
# # tensorboard
|
||||
# # check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
|
||||
# # check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
|
||||
# # check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
|
||||
# # check_argument('checkpoint', c, restricted=True, val_type=bool)
|
||||
# # check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
|
||||
|
||||
# dataloading
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from TTS.tts.utils.text import cleaners
|
||||
# check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
|
||||
# check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
|
||||
# check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
|
||||
# check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
|
||||
# check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
|
||||
# check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
|
||||
# check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
|
||||
# check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)
|
||||
# # dataloading
|
||||
# # pylint: disable=import-outside-toplevel
|
||||
# from TTS.tts.utils.text import cleaners
|
||||
# # check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
|
||||
# # check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
|
||||
# # check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
|
||||
# # check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
|
||||
# # check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
|
||||
# # check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
|
||||
# # check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
|
||||
# # check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)
|
||||
|
||||
# paths
|
||||
# check_argument('output_path', c, restricted=True, val_type=str)
|
||||
# # paths
|
||||
# # check_argument('output_path', c, restricted=True, val_type=str)
|
||||
|
||||
# multi-speaker and gst
|
||||
# check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
|
||||
# check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool)
|
||||
# check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str)
|
||||
if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
|
||||
# check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
|
||||
# check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
|
||||
# check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
|
||||
# check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
|
||||
# check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
|
||||
# check_argument('gst_num_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
|
||||
# # multi-speaker and gst
|
||||
# # check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
|
||||
# # check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool)
|
||||
# # check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str)
|
||||
# if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
|
||||
# # check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
|
||||
# # check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
|
||||
# # check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
|
||||
# # check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
|
||||
# # check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
|
||||
# # check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
|
||||
# # check_argument('gst_num_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
|
||||
|
||||
# datasets - checking only the first entry
|
||||
# check_argument('datasets', c, restricted=True, val_type=list)
|
||||
# for dataset_entry in c['datasets']:
|
||||
# check_argument('name', dataset_entry, restricted=True, val_type=str)
|
||||
# check_argument('path', dataset_entry, restricted=True, val_type=str)
|
||||
# check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
|
||||
# check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
||||
# # datasets - checking only the first entry
|
||||
# # check_argument('datasets', c, restricted=True, val_type=list)
|
||||
# # for dataset_entry in c['datasets']:
|
||||
# # check_argument('name', dataset_entry, restricted=True, val_type=str)
|
||||
# # check_argument('path', dataset_entry, restricted=True, val_type=str)
|
||||
# # check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
|
||||
# # check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
||||
|
|
|
@ -8,12 +8,10 @@ import json
|
|||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from TTS.tts.utils.text.symbols import parse_symbols
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
|
||||
from TTS.utils.io import copy_model_files, load_config
|
||||
from TTS.utils.io import copy_model_files
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
|
||||
|
||||
|
@ -140,11 +138,11 @@ def process_args(args, config, tb_prefix):
|
|||
if not args.best_path:
|
||||
args.best_path = best_model
|
||||
# setup output paths and read configs
|
||||
c = config.load_json(args.config_path)
|
||||
if c.mixed_precision:
|
||||
config.load_json(args.config_path)
|
||||
if config.mixed_precision:
|
||||
print(" > Mixed precision mode is ON")
|
||||
if not os.path.exists(c.output_path):
|
||||
out_path = create_experiment_folder(c.output_path, c.run_name,
|
||||
if not os.path.exists(config.output_path):
|
||||
out_path = create_experiment_folder(config.output_path, config.run_name,
|
||||
args.debug)
|
||||
audio_path = os.path.join(out_path, "test_audios")
|
||||
# setup rank 0 process in distributed training
|
||||
|
@ -157,7 +155,7 @@ def process_args(args, config, tb_prefix):
|
|||
# if model characters are not set in the config file
|
||||
# save the default set to the config file for future
|
||||
# compatibility.
|
||||
if c.has('characters_config'):
|
||||
if config.has('characters_config'):
|
||||
used_characters = parse_symbols()
|
||||
new_fields["characters"] = used_characters
|
||||
copy_model_files(c, args.config_path, out_path, new_fields)
|
||||
|
@ -166,6 +164,6 @@ def process_args(args, config, tb_prefix):
|
|||
log_path = out_path
|
||||
tb_logger = TensorboardLogger(log_path, model_name=tb_prefix)
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text("model-description", c["run_description"], 0)
|
||||
tb_logger.tb_add_text("model-description", config["run_description"], 0)
|
||||
c_logger = ConsoleLogger()
|
||||
return c, out_path, audio_path, c_logger, tb_logger
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import glob
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -67,6 +71,20 @@ def count_parameters(model):
|
|||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
def to_camel(text):
    """Convert a snake_case name to CamelCase, keeping the 'TTS' acronym fully upper-cased.

    Args:
        text (str): snake_case identifier, e.g. a module name like ``glow_tts``.

    Returns:
        str: CamelCase form, e.g. ``GlowTTS``.
    """
    # Capitalize the head, then upper-case every letter that follows an underscore.
    camel = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text.capitalize())
    # Restore the project acronym spelling.
    return camel.replace('Tts', 'TTS')
|
||||
|
||||
|
||||
def find_module(module_path: str, module_name: str) -> object:
    """Dynamically import ``module_path.module_name`` and return its main class.

    The class is looked up by the CamelCase form of the module name
    (e.g. module ``glow_tts`` -> class ``GlowTTS``).

    Args:
        module_path (str): package path, e.g. ``TTS.tts.configs``.
        module_name (str): module name; lower-cased before import.

    Returns:
        object: the class object resolved from the imported module.
    """
    module_name = module_name.lower()
    target = importlib.import_module('.'.join((module_path, module_name)))
    return getattr(target, to_camel(module_name))
|
||||
|
||||
|
||||
def get_user_data_dir(appname):
|
||||
if sys.platform == "win32":
|
||||
import winreg # pylint: disable=import-outside-toplevel
|
||||
|
@ -139,32 +157,3 @@ class KeepAverage:
|
|||
for key, value in value_dict.items():
|
||||
self.update_value(key, value)
|
||||
|
||||
|
||||
def check_argument(name,
                   c,
                   prerequest=None,
                   enum_list=None,
                   max_val=None,
                   min_val=None,
                   restricted=False,
                   alternative=None,
                   allow_none=False):
    """Validate a single field of a config dictionary.

    Args:
        name (str): config field to check.
        c (dict): config dictionary.
        prerequest (str or list, optional): field(s) that must exist in ``c``
            for this check to apply; if any is missing the check is skipped.
        enum_list (list, optional): allowed values (compared lower-cased).
        max_val (number, optional): inclusive upper bound.
        min_val (number, optional): inclusive lower bound.
        restricted (bool): if True, ``name`` must be present in ``c``.
        alternative (str, optional): if this field is set in ``c``, the check
            is skipped.
        allow_none (bool): accept ``None`` as a valid value and skip bounds.

    Raises:
        AssertionError: when a constraint is violated.
    """
    # BUG FIX: the original did `isinstance(prerequest, List())`, which raises at
    # runtime (typing.List cannot be instantiated), and with the default
    # prerequest=None the `prerequest not in c` branch returned early, so the
    # function never validated anything. Only skip when a prerequisite is given
    # and missing.
    if prerequest is not None:
        if isinstance(prerequest, list):
            if any(f not in c.keys() for f in prerequest):
                return
        elif prerequest not in c.keys():
            return
    # Skip when an alternative field is defined instead of this one.
    if alternative in c.keys() and c[alternative] is not None:
        return
    # Accept an explicit None value when allowed (guard the lookup so a missing
    # key still falls through to the `restricted` assertion below).
    if allow_none and name in c.keys() and c[name] is None:
        return
    if restricted:
        assert name in c.keys(), f" [!] {name} not defined in config.json"
    if name in c.keys():
        # BUG FIX: compare against None so 0 is a usable bound (`if max_val:`
        # silently skipped the check for max_val=0 / min_val=0).
        if max_val is not None:
            assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}"
        if min_val is not None:
            assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}"
        if enum_list:
            assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
|
||||
|
|
|
@ -3,6 +3,7 @@ import os
|
|||
import pickle as pickle_tts
|
||||
import re
|
||||
from shutil import copyfile
|
||||
from TTS.utils.generic_utils import find_module
|
||||
|
||||
import yaml
|
||||
|
||||
|
@ -23,32 +24,37 @@ class AttrDict(dict):
|
|||
self.__dict__ = self
|
||||
|
||||
|
||||
def read_json_with_comments(json_path):
    """DEPRECATED: read a JSON file, stripping ``//`` line comments and
    backslash-escaped newlines before parsing.

    Args:
        json_path (str): path to the JSON file.

    Returns:
        dict: parsed JSON content.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        raw = f.read()
    # join lines ended with a backslash, then drop `//` comments
    raw = re.sub(r'\\\n', '', raw)
    raw = re.sub(r'//.*\n', '\n', raw)
    return json.loads(raw)
|
||||
|
||||
def load_config(config_path: str) -> AttrDict:
    """DEPRECATED: Load a yaml or json config file and return the matching
    model config object.

    The raw file content is read into an ``AttrDict``; the ``model`` field
    selects the config class (e.g. model ``tacotron2`` -> ``Tacotron2Config``
    from ``TTS.tts.configs``), which is then populated from the dictionary.

    Args:
        config_path (str): path to config file (``.yml``/``.yaml`` or json).

    Returns:
        The populated model config object.
    """
    config_dict = AttrDict()
    ext = os.path.splitext(config_path)[1]
    if ext in (".yml", ".yaml"):
        with open(config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    else:
        # NOTE(review): plain json.loads here — comment-stripping is handled by
        # the deprecated read_json_with_comments elsewhere; confirm configs are
        # comment-free before relying on this path.
        with open(config_path, "r", encoding="utf-8") as f:
            data = json.loads(f.read())
    config_dict.update(data)
    # resolve the model-specific config class by naming convention
    config_class = find_module('TTS.tts.configs', config_dict.model.lower() + '_config')
    config = config_class()
    config.from_dict(config_dict)
    # BUG FIX: the original ended with a bare `return`, so callers always got
    # None despite the `-> AttrDict` annotation.
    return config
|
||||
|
||||
|
||||
def copy_model_files(c, config_file, out_path, new_fields):
|
||||
|
|
Loading…
Reference in New Issue