black formatting

Eren Gölge 2021-05-03 16:42:15 +02:00
parent c34c8137d7
commit 9c18e40f64
4 changed files with 126 additions and 114 deletions
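Every hunk below is mechanical restyling by the black formatter: single quotes become double quotes, hand-wrapped statements are collapsed or exploded to fit the configured line length, and trailing commas are added before closing brackets. A minimal sketch of reproducing one such change through black's Python API (assumes the black package is installed; the 120-character line length is an assumption inferred from how far this diff lets lines run):

import black

# One of the lines this commit touches, in its pre-commit form.
src = "compute_linear_spec=config.model.lower() == 'tacotron'\n"

# line_length=120 is an assumption based on the reflowed lines below.
mode = black.Mode(line_length=120)
print(black.format_str(src, mode=mode))
# -> compute_linear_spec = config.model.lower() == "tacotron"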

View File

@@ -10,6 +10,7 @@ from random import randrange
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
+
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import TacotronLoss
@@ -22,10 +23,8 @@ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.arguments import init_training
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
-                                  init_distributed, reduce_tensor)
-from TTS.utils.generic_utils import (KeepAverage, count_parameters,
-                                     remove_experiment_folder, set_init_dict)
+from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor
+from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
 from TTS.utils.radam import RAdam
 from TTS.utils.training import (
     NoamLR,
@@ -47,13 +46,12 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
         dataset = MyDataset(
             r,
             config.text_cleaner,
-            compute_linear_spec=config.model.lower() == 'tacotron',
+            compute_linear_spec=config.model.lower() == "tacotron",
             meta_data=meta_data_eval if is_val else meta_data_train,
             ap=ap,
             tp=config.characters,
-            add_blank=config['add_blank'],
-            batch_group_size=0 if is_val else config.batch_group_size *
-            config.batch_size,
+            add_blank=config["add_blank"],
+            batch_group_size=0 if is_val else config.batch_group_size * config.batch_size,
             min_seq_len=config.min_seq_len,
             max_seq_len=config.max_seq_len,
             phoneme_cache_path=config.phoneme_cache_path,
@@ -61,11 +59,12 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
             phoneme_language=config.phoneme_language,
             enable_eos_bos=config.enable_eos_bos_chars,
             verbose=verbose,
-            speaker_mapping=(speaker_mapping if (
-                config.use_speaker_embedding
-                and config.use_external_speaker_embedding_file
-            ) else None)
-        )
+            speaker_mapping=(
+                speaker_mapping
+                if (config.use_speaker_embedding and config.use_external_speaker_embedding_file)
+                else None
+            ),
+        )

         if config.use_phonemes and config.compute_input_seq_cache:
             # precompute phonemes to have a better estimate of sequence lengths.
@@ -80,9 +79,9 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
             collate_fn=dataset.collate_fn,
             drop_last=False,
             sampler=sampler,
-            num_workers=config.num_val_loader_workers
-            if is_val else config.num_loader_workers,
-            pin_memory=False)
+            num_workers=config.num_val_loader_workers if is_val else config.num_loader_workers,
+            pin_memory=False,
+        )
         return loader
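For context on the batch_group_size argument above (0 for validation, batch_group_size * batch_size for training): MyDataset uses it for length-bucketed shuffling, keeping each batch roughly length-homogeneous while still randomizing order. A rough sketch of the idea, not the library's exact code:

import random

def group_shuffle(items, lengths, batch_group_size):
    # sort by length, then shuffle only within consecutive groups,
    # so similar-length samples still land in the same batch
    order = sorted(range(len(items)), key=lambda i: lengths[i])
    if batch_group_size > 0:  # 0 (the validation setting) keeps the pure sorted order
        for start in range(0, len(order), batch_group_size):
            chunk = order[start : start + batch_group_size]
            random.shuffle(chunk)
            order[start : start + batch_group_size] = chunk
    return [items[i] for i in order]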
@@ -111,10 +110,8 @@ def format_data(data):
         speaker_ids = None

     # set stop targets view, we predict a single stop token per iteration.
-    stop_targets = stop_targets.view(text_input.shape[0],
-                                     stop_targets.size(1) // config.r, -1)
-    stop_targets = (stop_targets.sum(2) >
-                    0.0).unsqueeze(2).float().squeeze(2)
+    stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // config.r, -1)
+    stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)

     # dispatch data to GPU
     if use_cuda:
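As context for the stop-target hunk above: the Tacotron decoder emits r frames per step, so format_data folds the per-frame stop flags into one flag per decoder step; the .unsqueeze(2).float().squeeze(2) in the original is a verbose .float(). A standalone sketch with hypothetical shapes:

import torch

B, T, r = 2, 12, 3  # hypothetical batch size, padded frame count, reduction factor
stop_targets = torch.zeros(B, T)
stop_targets[:, 9:] = 1.0  # frames past the end of the utterance are flagged "stop"

stop_targets = stop_targets.view(B, T // r, -1)     # (B, T // r, r): one block per decoder step
stop_targets = (stop_targets.sum(2) > 0.0).float()  # a step stops if any of its r frames stops
print(stop_targets.shape)  # torch.Size([2, 4])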
@@ -148,8 +145,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
     epoch_time = 0
     keep_avg = KeepAverage()
     if use_cuda:
-        batch_n_iter = int(
-            len(data_loader.dataset) / (config.batch_size * num_gpus))
+        batch_n_iter = int(len(data_loader.dataset) / (config.batch_size * num_gpus))
     else:
         batch_n_iter = int(len(data_loader.dataset) / config.batch_size)
     end_time = time.time()
@@ -185,8 +181,21 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
         with torch.cuda.amp.autocast(enabled=config.mixed_precision):
             # forward pass model
             if config.bidirectional_decoder or config.double_decoder_consistency:
-                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
-                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
+                (
+                    decoder_output,
+                    postnet_output,
+                    alignments,
+                    stop_tokens,
+                    decoder_backward_output,
+                    alignments_backward,
+                ) = model(
+                    text_input,
+                    text_lengths,
+                    mel_input,
+                    mel_lengths,
+                    speaker_ids=speaker_ids,
+                    speaker_embeddings=speaker_embeddings,
+                )
             else:
                 decoder_output, postnet_output, alignments, stop_tokens = model(
                     text_input,
@@ -239,7 +248,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
             # stopnet optimizer step
             if config.separate_stopnet:
-                scaler_st.scale(loss_dict['stopnet_loss']).backward()
+                scaler_st.scale(loss_dict["stopnet_loss"]).backward()
                 scaler.unscale_(optimizer_st)
                 optimizer_st, _ = adam_weight_decay(optimizer_st)
                 grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
@@ -256,7 +265,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
             # stopnet optimizer step
             if config.separate_stopnet:
-                loss_dict['stopnet_loss'].backward()
+                loss_dict["stopnet_loss"].backward()
                 optimizer_st, _ = adam_weight_decay(optimizer_st)
                 grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                 optimizer_st.step()
@@ -272,10 +281,12 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
         # aggregate losses from processes
         if num_gpus > 1:
-            loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
-            loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
-            loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
-            loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if config.stopnet else loss_dict['stopnet_loss']
+            loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
+            loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
+            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
+            loss_dict["stopnet_loss"] = (
+                reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) if config.stopnet else loss_dict["stopnet_loss"]
+            )

         # detach loss values
         loss_dict_new = dict()
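reduce_tensor, imported from TTS.utils.distribute at the top of the file, averages a loss across the data-parallel workers so every rank logs the same value. A sketch of the usual all-reduce pattern behind such a helper (the library's exact body may differ):

import torch
import torch.distributed as dist

def reduce_tensor(tensor: torch.Tensor, num_gpus: int) -> torch.Tensor:
    rt = tensor.clone()                        # leave the caller's tensor untouched
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sum the value across all processes
    return rt / num_gpus                       # turn the sum into a mean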
@@ -321,17 +332,26 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
         if global_step % config.save_step == 0:
             if config.checkpoint:
                 # save model
-                save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
-                                optimizer_st=optimizer_st,
-                                model_loss=loss_dict['postnet_loss'],
-                                characters=model_characters,
-                                scaler=scaler.state_dict() if config.mixed_precision else None)
+                save_checkpoint(
+                    model,
+                    optimizer,
+                    global_step,
+                    epoch,
+                    model.decoder.r,
+                    OUT_PATH,
+                    optimizer_st=optimizer_st,
+                    model_loss=loss_dict["postnet_loss"],
+                    characters=model_characters,
+                    scaler=scaler.state_dict() if config.mixed_precision else None,
+                )

             # Diagnostic visualizations
             const_spec = postnet_output[0].data.cpu().numpy()
-            gt_spec = linear_input[0].data.cpu().numpy() if config.model in [
-                "Tacotron", "TacotronGST"
-            ] else mel_input[0].data.cpu().numpy()
+            gt_spec = (
+                linear_input[0].data.cpu().numpy()
+                if config.model in ["Tacotron", "TacotronGST"]
+                else mel_input[0].data.cpu().numpy()
+            )
             align_img = alignments[0].data.cpu().numpy()

             figures = {
@@ -341,7 +361,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
             }
             if config.bidirectional_decoder or config.double_decoder_consistency:
-                figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
+                figures["alignment_backward"] = plot_alignment(
+                    alignments_backward[0].data.cpu().numpy(), output_fig=False
+                )

             tb_logger.tb_train_figures(global_step, figures)
@@ -350,9 +372,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
                 train_audio = ap.inv_spectrogram(const_speconfig.T)
             else:
                 train_audio = ap.inv_melspectrogram(const_speconfig.T)
-            tb_logger.tb_train_audios(global_step,
-                                      {'TrainAudio': train_audio},
-                                      config.audio["sample_rate"])
+            tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"])
         end_time = time.time()

     # print epoch stats
@@ -395,8 +415,16 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                 # forward pass model
                 if config.bidirectional_decoder or config.double_decoder_consistency:
-                    decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
-                        text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
+                    (
+                        decoder_output,
+                        postnet_output,
+                        alignments,
+                        stop_tokens,
+                        decoder_backward_output,
+                        alignments_backward,
+                    ) = model(
+                        text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
+                    )
                 else:
                     decoder_output, postnet_output, alignments, stop_tokens = model(
                         text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
@@ -438,10 +466,10 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                 # aggregate losses from processes
                 if num_gpus > 1:
-                    loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
-                    loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
+                    loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
+                    loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
                     if config.stopnet:
-                        loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus)
+                        loss_dict["stopnet_loss"] = reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus)

                 # detach loss values
                 loss_dict_new = dict()
@@ -465,9 +493,11 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
         # Diagnostic visualizations
         idx = np.random.randint(mel_input.shape[0])
         const_spec = postnet_output[idx].data.cpu().numpy()
-        gt_spec = linear_input[idx].data.cpu().numpy() if config.model in [
-            "Tacotron", "TacotronGST"
-        ] else mel_input[idx].data.cpu().numpy()
+        gt_spec = (
+            linear_input[idx].data.cpu().numpy()
+            if config.model in ["Tacotron", "TacotronGST"]
+            else mel_input[idx].data.cpu().numpy()
+        )
         align_img = alignments[idx].data.cpu().numpy()

         eval_figures = {
@@ -481,8 +511,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
            eval_audio = ap.inv_spectrogram(const_speconfig.T)
        else:
            eval_audio = ap.inv_melspectrogram(const_speconfig.T)
-        tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
-                                 config.audio["sample_rate"])
+        tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"])

         # Plot Validation Stats
@@ -510,13 +539,17 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
         test_figures = {}
         print(" | > Synthesizing test sentences")
         speaker_id = 0 if config.use_speaker_embedding else None
-        speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] if config.use_external_speaker_embedding_file and config.use_speaker_embedding else None
+        speaker_embedding = (
+            speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]]["embedding"]
+            if config.use_external_speaker_embedding_file and config.use_speaker_embedding
+            else None
+        )
         style_wav = config.get("gst_style_input")
         if style_wav is None and config.use_gst:
             # inicialize GST with zero dict.
             style_wav = {}
             print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
-            for i in range(config.gst['gst_num_style_tokens']):
+            for i in range(config.gst["gst_num_style_tokens"]):
                 style_wav[str(i)] = 0
             style_wav = config.get("gst_style_input")
         for idx, test_sentence in enumerate(test_sentences):
@@ -531,7 +564,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                 speaker_embedding=speaker_embedding,
                 style_wav=style_wav,
                 truncated=False,
-                enable_eos_bos_chars=config.enable_eos_bos_chars, #pylint: disable=unused-argument
+                enable_eos_bos_chars=config.enable_eos_bos_chars,  # pylint: disable=unused-argument
                 use_griffin_lim=True,
                 do_trim_silence=False,
             )
@@ -546,8 +579,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
             except: # pylint: disable=bare-except
                 print(" !! Error creating Test Sentence -", idx)
                 traceback.print_exc()
-        tb_logger.tb_test_audios(global_step, test_audios,
-                                 config.audio['sample_rate'])
+        tb_logger.tb_test_audios(global_step, test_audios, config.audio["sample_rate"])
         tb_logger.tb_test_figures(global_step, test_figures)
     return keep_avg.avg_values
@@ -564,8 +596,7 @@ def main(args): # pylint: disable=redefined-outer-name
     # DISTRUBUTED
     if num_gpus > 1:
-        init_distributed(args.rank, num_gpus, args.group_id,
-                         config.distributed["backend"], config.distributed["url"])
+        init_distributed(args.rank, num_gpus, args.group_id, config.distributed["backend"], config.distributed["url"])
     num_chars = len(phonemes) if config.use_phonemes else len(symbols)
     model_characters = phonemes if config.use_phonemes else symbols
@@ -573,10 +604,10 @@ def main(args): # pylint: disable=redefined-outer-name
     meta_data_train, meta_data_eval = load_meta_data(config.datasets)

     # set the portion of the data used for training
-    if config.has('train_portion'):
-        meta_data_train = meta_data_train[:int(len(meta_data_train) * config.train_portion)]
-    if config.has('eval_portion'):
-        meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * config.eval_portion)]
+    if config.has("train_portion"):
+        meta_data_train = meta_data_train[: int(len(meta_data_train) * config.train_portion)]
+    if config.has("eval_portion"):
+        meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * config.eval_portion)]

     # parse speakers
     num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(config, args, meta_data_train, OUT_PATH)
@@ -590,9 +621,7 @@ def main(args): # pylint: disable=redefined-outer-name
     params = set_weight_decay(model, config.wd)
     optimizer = RAdam(params, lr=config.lr, weight_decay=0)
     if config.stopnet and config.separate_stopnet:
-        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
-                             lr=config.lr,
-                             weight_decay=0)
+        optimizer_st = RAdam(model.decoder.stopnet.parameters(), lr=config.lr, weight_decay=0)
     else:
         optimizer_st = None
@@ -606,7 +635,7 @@ def main(args): # pylint: disable=redefined-outer-name
         model.load_state_dict(checkpoint["model"])
         # optimizer restore
         print(" > Restoring Optimizer...")
-        optimizer.load_state_dict(checkpoint['optimizer'])
+        optimizer.load_state_dict(checkpoint["optimizer"])
         if "scaler" in checkpoint and config.mixed_precision:
             print(" > Restoring AMP Scaler...")
             scaler.load_state_dict(checkpoint["scaler"])
@@ -622,10 +651,9 @@ def main(args): # pylint: disable=redefined-outer-name
             del model_dict
         for group in optimizer.param_groups:
-            group['lr'] = config.lr
-        print(" > Model restored from step %d" % checkpoint['step'],
-              flush=True)
-        args.restore_step = checkpoint['step']
+            group["lr"] = config.lr
+        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
+        args.restore_step = checkpoint["step"]
     else:
         args.restore_step = 0
@@ -638,9 +666,7 @@ def main(args): # pylint: disable=redefined-outer-name
         model = apply_gradient_allreduce(model)

     if config.noam_schedule:
-        scheduler = NoamLR(optimizer,
-                           warmup_steps=config.warmup_steps,
-                           last_epoch=args.restore_step - 1)
+        scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps, last_epoch=args.restore_step - 1)
     else:
         scheduler = None
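NoamLR applies the warmup schedule popularized by "Attention Is All You Need": the learning rate ramps up for warmup_steps and then decays with the inverse square root of the step count, which is why last_epoch is seeded from args.restore_step when resuming. A sketch of the scale factor (the library implementation may differ in detail):

def noam_scale(step: int, warmup_steps: int) -> float:
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first step
    return warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)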
@@ -693,9 +719,9 @@ def main(args): # pylint: disable=redefined-outer-name
         # eval one epoch
         eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
-        target_loss = train_avg_loss_dict['avg_postnet_loss']
+        target_loss = train_avg_loss_dict["avg_postnet_loss"]
         if config.run_eval:
-            target_loss = eval_avg_loss_dict['avg_postnet_loss']
+            target_loss = eval_avg_loss_dict["avg_postnet_loss"]
         best_loss = save_best_model(
             target_loss,
             best_loss,
@@ -708,11 +734,11 @@ def main(args): # pylint: disable=redefined-outer-name
             model_characters,
             keep_all_best=keep_all_best,
             keep_after=keep_after,
-            scaler=scaler.state_dict() if config.mixed_precision else None
+            scaler=scaler.state_dict() if config.mixed_precision else None,
         )


-if __name__ == '__main__':
+if __name__ == "__main__":
     args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
     try:
         main(args)

View File

@@ -6,8 +6,8 @@ import argparse
 import glob
 import json
 import os
-import sys
 import re
+import sys

 from TTS.tts.utils.text.symbols import parse_symbols
 from TTS.utils.console_logger import ConsoleLogger
@@ -46,24 +46,14 @@ def init_arguments(argv):
             "Best model file to be used for extracting best loss."
             "If not specified, the latest best model in continue path is used"
         ),
-        default="")
-    parser.add_argument("--config_path",
-                        type=str,
-                        help="Path to config file for training.",
-                        required="--continue_path" not in argv)
-    parser.add_argument("--debug",
-                        type=bool,
-                        default=False,
-                        help="Do not verify commit integrity to run training.")
-    parser.add_argument(
-        "--rank",
-        type=int,
-        default=0,
-        help="DISTRIBUTED: process rank for distributed training.")
-    parser.add_argument("--group_id",
-                        type=str,
-                        default="",
-                        help="DISTRIBUTED: process group id.")
+        default="",
+    )
+    parser.add_argument(
+        "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
+    )
+    parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
+    parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
+    parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
     return parser
@@ -157,8 +147,7 @@ def process_args(args):
     if config.mixed_precision:
         print(" > Mixed precision mode is ON")
     if not os.path.exists(config.output_path):
-        experiment_path = create_experiment_folder(config.output_path,
-                                                   config.run_name, args.debug)
+        experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
     else:
         experiment_path = config.output_path
     audio_path = os.path.join(experiment_path, "test_audios")
@@ -172,17 +161,15 @@ def process_args(args):
         # if model characters are not set in the config file
         # save the default set to the config file for future
         # compatibility.
-        if config.has('characters_config'):
+        if config.has("characters_config"):
             used_characters = parse_symbols()
-            new_fields['characters'] = used_characters
+            new_fields["characters"] = used_characters
         copy_model_files(config, args.config_path, experiment_path, new_fields)
     os.chmod(audio_path, 0o775)
     os.chmod(experiment_path, 0o775)
-    tb_logger = TensorboardLogger(experiment_path,
-                                  model_name=config.model)
+    tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
     # write model desc to tensorboard
-    tb_logger.tb_add_text("model-description", config["run_description"],
-                          0)
+    tb_logger.tb_add_text("model-description", config["run_description"], 0)
     c_logger = ConsoleLogger()
     return config, experiment_path, audio_path, c_logger, tb_logger

View File

@@ -73,14 +73,14 @@ def count_parameters(model):
 def to_camel(text):
     text = text.capitalize()
-    text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
-    text = text.replace('Tts', 'TTS')
+    text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
+    text = text.replace("Tts", "TTS")
     return text


 def find_module(module_path: str, module_name: str) -> object:
     module_name = module_name.lower()
-    module = importlib.import_module(module_path+'.'+module_name)
+    module = importlib.import_module(module_path + "." + module_name)
     class_name = to_camel(module_name)
     return getattr(module, class_name)
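The pair above is a small name-based registry: find_module imports module_path.module_name and returns the class whose name is the camel-cased module name. A hedged usage sketch (the config class name is an assumption consistent with load_config in the last file of this commit):

print(to_camel("tacotron2"))  # -> "Tacotron2"
print(to_camel("glow_tts"))   # -> "GlowTTS"  (the Tts -> TTS replacement at work)

# As load_config() does below: model name "tacotron2" would resolve to
# TTS.tts.configs.tacotron2_config.Tacotron2Config.
config_class = find_module("TTS.tts.configs", "tacotron2" + "_config")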
@@ -156,4 +156,3 @@ class KeepAverage:
     def update_values(self, value_dict):
         for key, value in value_dict.items():
             self.update_value(key, value)
-

View File

@@ -5,6 +5,7 @@ import re
 from shutil import copyfile

 import yaml
+
 from TTS.utils.generic_utils import find_module
 from .generic_utils import find_module
@@ -32,8 +33,8 @@ def read_json_with_comments(json_path):
     with open(json_path, "r", encoding="utf-8") as f:
         input_str = f.read()
     # handle comments
-    input_str = re.sub(r'\\\n', '', input_str)
-    input_str = re.sub(r'//.*\n', '\n', input_str)
+    input_str = re.sub(r"\\\n", "", input_str)
+    input_str = re.sub(r"//.*\n", "\n", input_str)
     data = json.loads(input_str)
     return data
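A quick illustration of the two substitutions above on a hypothetical commented config: the first regex joins lines continued with a trailing backslash, the second drops // comments up to the end of the line.

import json
import re

input_str = '{\n  "model": "tacotron2", // chosen model\n  "r": 2\n}\n'
input_str = re.sub(r"\\\n", "", input_str)      # join backslash-continued lines
input_str = re.sub(r"//.*\n", "\n", input_str)  # strip // comments
print(json.loads(input_str))  # {'model': 'tacotron2', 'r': 2}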
@@ -44,20 +45,19 @@ def load_config(config_path: str) -> None:
     if ext in (".yml", ".yaml"):
         with open(config_path, "r", encoding="utf-8") as f:
             data = yaml.safe_load(f)
-    elif ext == '.json':
+    elif ext == ".json":
         with open(config_path, "r", encoding="utf-8") as f:
             input_str = f.read()
             data = json.loads(input_str)
     else:
-        raise TypeError(f' [!] Unknown config file type {ext}')
+        raise TypeError(f" [!] Unknown config file type {ext}")
     config_dict.update(data)
-    config_class = find_module('TTS.tts.configs', config_dict['model'].lower()+'_config')
+    config_class = find_module("TTS.tts.configs", config_dict["model"].lower() + "_config")
     config = config_class()
     config.from_dict(config_dict)
     return config


 def copy_model_files(c, config_file, out_path, new_fields):
     """Copy config.json and other model files to training folder and add
     new fields.