config refactor #5 WIP

Eren Gölge 2021-04-02 14:24:12 +02:00
parent dc50f5f0b0
commit 79d7215142
6 changed files with 236 additions and 244 deletions

TTS/bin/compute_statistics.py

@@ -8,7 +8,6 @@ import os

 import numpy as np
 from tqdm import tqdm

-from TTS.utils.config_manager import ConfigManager
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_config
@@ -16,8 +15,6 @@ from TTS.utils.io import load_config

 def main():
     """Run preprocessing process."""
-    CONFIG = ConfigManager()
-
     parser = argparse.ArgumentParser(
         description="Compute mean and variance of spectrogram features.")
     parser.add_argument("config_path", type=str,
@@ -26,17 +23,17 @@ def main():
                         help="save path (directory and filename).")
     parser.add_argument("--data_path", type=str, required=False,
                         help="folder including the target set of wavs overriding dataset config.")
-    parser = CONFIG.init_argparse(parser)
-    args = parser.parse_args()
-    CONFIG.parse_argparse(args)
+    args, overrides = parser.parse_known_args()
+    CONFIG = load_config(args.config_path)
+    CONFIG.parse_args(overrides)

     # load config
-    CONFIG.load_config(args.config_path)
-    CONFIG.audio_config.signal_norm = False  # do not apply earlier normalization
-    CONFIG.audio_config.stats_path = None  # discard pre-defined stats
+    CONFIG.audio.signal_norm = False  # do not apply earlier normalization
+    CONFIG.audio.stats_path = None  # discard pre-defined stats

     # load audio processor
-    ap = AudioProcessor(**CONFIG.audio_config.to_dict())
+    ap = AudioProcessor(**CONFIG.audio.to_dict())

     # load the meta data of target dataset
     if args.data_path:
@@ -81,15 +78,14 @@ def main():
     print(f" > Avg linear spec scale: {linear_scale.mean()}")

     # set default config values for mean-var scaling
-    CONFIG.audio_config.stats_path = output_file_path
-    CONFIG.audio_config.signal_norm = True
+    CONFIG.audio.stats_path = output_file_path
+    CONFIG.audio.signal_norm = True
     # remove redundant values
-    del CONFIG.audio_config.max_norm
-    del CONFIG.audio_config.min_level_db
-    del CONFIG.audio_config.symmetric_norm
-    del CONFIG.audio_config.clip_norm
-    breakpoint()
-    stats['audio_config'] = CONFIG.audio_config.to_dict()
+    del CONFIG.audio.max_norm
+    del CONFIG.audio.min_level_db
+    del CONFIG.audio.symmetric_norm
+    del CONFIG.audio.clip_norm
+    stats['audio_config'] = CONFIG.audio.to_dict()
     np.save(output_file_path, stats, allow_pickle=True)
     print(f" > stats saved to {output_file_path}")

TTS/bin/train_tacotron.py

@@ -20,9 +20,8 @@ from TTS.tts.utils.speakers import parse_speakers
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.arguments import parse_arguments, process_args
+from TTS.utils.arguments import init_training
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.config_manager import ConfigManager
 from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
                                   init_distributed, reduce_tensor)
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
@@ -41,47 +40,49 @@ use_cuda, num_gpus = setup_torch_training_env(True, False)

 def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
-    if is_val and not c.run_eval:
+    if is_val and not config.run_eval:
         loader = None
     else:
         if dataset is None:
             dataset = MyDataset(
                 r,
-                c.text_cleaner,
-                compute_linear_spec=c.model.lower() == "tacotron",
+                config.text_cleaner,
+                compute_linear_spec=config.model.lower() == 'tacotron',
                 meta_data=meta_data_eval if is_val else meta_data_train,
                 ap=ap,
-                tp=c.characters if "characters" in c.keys() else None,
-                add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
-                batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
-                min_seq_len=c.min_seq_len,
-                max_seq_len=c.max_seq_len,
-                phoneme_cache_path=c.phoneme_cache_path,
-                use_phonemes=c.use_phonemes,
-                phoneme_language=c.phoneme_language,
-                enable_eos_bos=c.enable_eos_bos_chars,
+                tp=config.characters,
+                add_blank=config['add_blank'],
+                batch_group_size=0 if is_val else config.batch_group_size *
+                config.batch_size,
+                min_seq_len=config.min_seq_len,
+                max_seq_len=config.max_seq_len,
+                phoneme_cache_path=config.phoneme_cache_path,
+                use_phonemes=config.use_phonemes,
+                phoneme_language=config.phoneme_language,
+                enable_eos_bos=config.enable_eos_bos_chars,
                 verbose=verbose,
-                speaker_mapping=(
-                    speaker_mapping if (c.use_speaker_embedding and c.use_external_speaker_embedding_file) else None
-                ),
-            )
+                speaker_mapping=(speaker_mapping if (
+                    config.use_speaker_embedding
+                    and config.use_external_speaker_embedding_file
+                ) else None)
+            )

-        if c.use_phonemes and c.compute_input_seq_cache:
+        if config.use_phonemes and config.compute_input_seq_cache:
             # precompute phonemes to have a better estimate of sequence lengths.
-            dataset.compute_input_seq(c.num_loader_workers)
+            dataset.compute_input_seq(config.num_loader_workers)
         dataset.sort_items()

         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
-            batch_size=c.eval_batch_size if is_val else c.batch_size,
+            batch_size=config.eval_batch_size if is_val else config.batch_size,
             shuffle=False,
             collate_fn=dataset.collate_fn,
             drop_last=False,
             sampler=sampler,
-            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
-            pin_memory=False,
-        )
+            num_workers=config.num_val_loader_workers
+            if is_val else config.num_loader_workers,
+            pin_memory=False)
     return loader
@@ -90,15 +91,15 @@ def format_data(data):
     text_input = data[0]
     text_lengths = data[1]
     speaker_names = data[2]
-    linear_input = data[3] if c.model.lower() in ["tacotron"] else None
+    linear_input = data[3] if config.model in ["Tacotron"] else None
     mel_input = data[4]
     mel_lengths = data[5]
     stop_targets = data[6]
     max_text_length = torch.max(text_lengths.float())
     max_spec_length = torch.max(mel_lengths.float())

-    if c.use_speaker_embedding:
-        if c.use_external_speaker_embedding_file:
+    if config.use_speaker_embedding:
+        if config.use_external_speaker_embedding_file:
             speaker_embeddings = data[8]
             speaker_ids = None
         else:
@@ -110,8 +111,10 @@ def format_data(data):
         speaker_ids = None

     # set stop targets view, we predict a single stop token per iteration.
-    stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
-    stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
+    stop_targets = stop_targets.view(text_input.shape[0],
+                                     stop_targets.size(1) // config.r, -1)
+    stop_targets = (stop_targets.sum(2) >
+                    0.0).unsqueeze(2).float().squeeze(2)

     # dispatch data to GPU
     if use_cuda:
@@ -119,7 +122,7 @@ def format_data(data):
         text_lengths = text_lengths.cuda(non_blocking=True)
         mel_input = mel_input.cuda(non_blocking=True)
         mel_lengths = mel_lengths.cuda(non_blocking=True)
-        linear_input = linear_input.cuda(non_blocking=True) if c.model.lower() in ["tacotron"] else None
+        linear_input = linear_input.cuda(non_blocking=True) if config.model.lower() in ["tacotron"] else None
         stop_targets = stop_targets.cuda(non_blocking=True)
         if speaker_ids is not None:
             speaker_ids = speaker_ids.cuda(non_blocking=True)
@@ -145,9 +148,10 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
     epoch_time = 0
     keep_avg = KeepAverage()
     if use_cuda:
-        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
+        batch_n_iter = int(
+            len(data_loader.dataset) / (config.batch_size * num_gpus))
     else:
-        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
+        batch_n_iter = int(len(data_loader.dataset) / config.batch_size)
     end_time = time.time()
     c_logger.print_train_start()
     for num_iter, data in enumerate(data_loader):
@@ -171,31 +175,18 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
         global_step += 1

         # setup lr
-        if c.noam_schedule:
+        if config.noam_schedule:
             scheduler.step()

         optimizer.zero_grad()
         if optimizer_st:
             optimizer_st.zero_grad()

-        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
+        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
             # forward pass model
-            if c.bidirectional_decoder or c.double_decoder_consistency:
-                (
-                    decoder_output,
-                    postnet_output,
-                    alignments,
-                    stop_tokens,
-                    decoder_backward_output,
-                    alignments_backward,
-                ) = model(
-                    text_input,
-                    text_lengths,
-                    mel_input,
-                    mel_lengths,
-                    speaker_ids=speaker_ids,
-                    speaker_embeddings=speaker_embeddings,
-                )
+            if config.bidirectional_decoder or config.double_decoder_consistency:
+                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
             else:
                 decoder_output, postnet_output, alignments, stop_tokens = model(
                     text_input,
@@ -237,18 +228,18 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
             raise RuntimeError(f"Detected NaN loss at step {global_step}.")

         # optimizer step
-        if c.mixed_precision:
+        if config.mixed_precision:
             # model optimizer step in mixed precision mode
             scaler.scale(loss_dict["loss"]).backward()
             scaler.unscale_(optimizer)
             optimizer, current_lr = adam_weight_decay(optimizer)
-            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+            grad_norm, _ = check_update(model, config.grad_clip, ignore_stopnet=True)
             scaler.step(optimizer)
             scaler.update()

             # stopnet optimizer step
-            if c.separate_stopnet:
-                scaler_st.scale(loss_dict["stopnet_loss"]).backward()
+            if config.separate_stopnet:
+                scaler_st.scale(loss_dict['stopnet_loss']).backward()
                 scaler.unscale_(optimizer_st)
                 optimizer_st, _ = adam_weight_decay(optimizer_st)
                 grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
@@ -260,12 +251,12 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
             # main model optimizer step
             loss_dict["loss"].backward()
             optimizer, current_lr = adam_weight_decay(optimizer)
-            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+            grad_norm, _ = check_update(model, config.grad_clip, ignore_stopnet=True)
             optimizer.step()

             # stopnet optimizer step
-            if c.separate_stopnet:
-                loss_dict["stopnet_loss"].backward()
+            if config.separate_stopnet:
+                loss_dict['stopnet_loss'].backward()
                 optimizer_st, _ = adam_weight_decay(optimizer_st)
                 grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                 optimizer_st.step()
@@ -281,12 +272,10 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,

         # aggregate losses from processes
         if num_gpus > 1:
-            loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
-            loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
-            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
-            loss_dict["stopnet_loss"] = (
-                reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) if c.stopnet else loss_dict["stopnet_loss"]
-            )
+            loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
+            loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
+            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
+            loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if config.stopnet else loss_dict['stopnet_loss']

         # detach loss values
         loss_dict_new = dict()
@@ -306,7 +295,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
         keep_avg.update_values(update_train_values)

         # print training progress
-        if global_step % c.print_step == 0:
+        if global_step % config.print_step == 0:
             log_dict = {
                 "max_spec_length": [max_spec_length, 1],  # value, precision
                 "max_text_length": [max_text_length, 1],
@@ -319,7 +308,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
         if args.rank == 0:
             # Plot Training Iter Stats
             # reduce TB load
-            if global_step % c.tb_plot_step == 0:
+            if global_step % config.tb_plot_step == 0:
                 iter_stats = {
                     "lr": current_lr,
                     "grad_norm": grad_norm,
@@ -329,29 +318,20 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
                 iter_stats.update(loss_dict)
                 tb_logger.tb_train_iter_stats(global_step, iter_stats)

-            if global_step % c.save_step == 0:
-                if c.checkpoint:
+            if global_step % config.save_step == 0:
+                if config.checkpoint:
                     # save model
-                    save_checkpoint(
-                        model,
-                        optimizer,
-                        global_step,
-                        epoch,
-                        model.decoder.r,
-                        OUT_PATH,
-                        optimizer_st=optimizer_st,
-                        model_loss=loss_dict["postnet_loss"],
-                        characters=model_characters,
-                        scaler=scaler.state_dict() if c.mixed_precision else None,
-                    )
+                    save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
+                                    optimizer_st=optimizer_st,
+                                    model_loss=loss_dict['postnet_loss'],
+                                    characters=model_characters,
+                                    scaler=scaler.state_dict() if config.mixed_precision else None)

                 # Diagnostic visualizations
                 const_spec = postnet_output[0].data.cpu().numpy()
-                gt_spec = (
-                    linear_input[0].data.cpu().numpy()
-                    if c.model in ["Tacotron", "TacotronGST"]
-                    else mel_input[0].data.cpu().numpy()
-                )
+                gt_spec = linear_input[0].data.cpu().numpy() if config.model in [
+                    "Tacotron", "TacotronGST"
+                ] else mel_input[0].data.cpu().numpy()
                 align_img = alignments[0].data.cpu().numpy()

                 figures = {
@@ -360,19 +340,19 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
                     "alignment": plot_alignment(align_img, output_fig=False),
                 }

-                if c.bidirectional_decoder or c.double_decoder_consistency:
-                    figures["alignment_backward"] = plot_alignment(
-                        alignments_backward[0].data.cpu().numpy(), output_fig=False
-                    )
+                if config.bidirectional_decoder or config.double_decoder_consistency:
+                    figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)

                 tb_logger.tb_train_figures(global_step, figures)

                 # Sample audio
-                if c.model in ["Tacotron", "TacotronGST"]:
+                if config.model in ["Tacotron", "TacotronGST"]:
                     train_audio = ap.inv_spectrogram(const_spec.T)
                 else:
                     train_audio = ap.inv_melspectrogram(const_spec.T)
-                tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"])
+                tb_logger.tb_train_audios(global_step,
+                                          {'TrainAudio': train_audio},
+                                          config.audio["sample_rate"])
         end_time = time.time()

     # print epoch stats
@@ -383,7 +363,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
         epoch_stats = {"epoch_time": epoch_time}
         epoch_stats.update(keep_avg.avg_values)
         tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
-        if c.tb_model_param_stats:
+        if config.tb_model_param_stats:
             tb_logger.tb_model_weights(model, global_step)
     return keep_avg.avg_values, global_step
@@ -414,17 +394,9 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
             assert mel_input.shape[1] % model.decoder.r == 0

             # forward pass model
-            if c.bidirectional_decoder or c.double_decoder_consistency:
-                (
-                    decoder_output,
-                    postnet_output,
-                    alignments,
-                    stop_tokens,
-                    decoder_backward_output,
-                    alignments_backward,
-                ) = model(
-                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
-                )
+            if config.bidirectional_decoder or config.double_decoder_consistency:
+                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
             else:
                 decoder_output, postnet_output, alignments, stop_tokens = model(
                     text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
@@ -466,10 +438,10 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):

             # aggregate losses from processes
             if num_gpus > 1:
-                loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
-                loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
-                if c.stopnet:
-                    loss_dict["stopnet_loss"] = reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus)
+                loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
+                loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
+                if config.stopnet:
+                    loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus)

             # detach loss values
             loss_dict_new = dict()
@@ -486,18 +458,16 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                 update_train_values["avg_" + key] = value
             keep_avg.update_values(update_train_values)

-            if c.print_eval:
+            if config.print_eval:
                 c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

     if args.rank == 0:
         # Diagnostic visualizations
         idx = np.random.randint(mel_input.shape[0])
         const_spec = postnet_output[idx].data.cpu().numpy()
-        gt_spec = (
-            linear_input[idx].data.cpu().numpy()
-            if c.model in ["Tacotron", "TacotronGST"]
-            else mel_input[idx].data.cpu().numpy()
-        )
+        gt_spec = linear_input[idx].data.cpu().numpy() if config.model in [
+            "Tacotron", "TacotronGST"
+        ] else mel_input[idx].data.cpu().numpy()
         align_img = alignments[idx].data.cpu().numpy()

         eval_figures = {
@@ -507,22 +477,23 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
         }

         # Sample audio
-        if c.model in ["Tacotron", "TacotronGST"]:
+        if config.model in ["Tacotron", "TacotronGST"]:
             eval_audio = ap.inv_spectrogram(const_spec.T)
         else:
             eval_audio = ap.inv_melspectrogram(const_spec.T)
-        tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
+        tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
+                                 config.audio["sample_rate"])

         # Plot Validation Stats
-        if c.bidirectional_decoder or c.double_decoder_consistency:
+        if config.bidirectional_decoder or config.double_decoder_consistency:
             align_b_img = alignments_backward[idx].data.cpu().numpy()
             eval_figures["alignment2"] = plot_alignment(align_b_img, output_fig=False)
         tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
         tb_logger.tb_eval_figures(global_step, eval_figures)

-    if args.rank == 0 and epoch > c.test_delay_epochs:
-        if c.test_sentences_file is None:
+    if args.rank == 0 and epoch > config.test_delay_epochs:
+        if config.test_sentences_file is None:
             test_sentences = [
                 "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                 "Be a voice, not an echo.",
@@ -531,40 +502,36 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                 "Prior to November 22, 1963.",
             ]
         else:
-            with open(c.test_sentences_file, "r") as f:
+            with open(config.test_sentences_file, "r") as f:
                 test_sentences = [s.strip() for s in f.readlines()]

         # test sentences
         test_audios = {}
         test_figures = {}
         print(" | > Synthesizing test sentences")
-        speaker_id = 0 if c.use_speaker_embedding else None
-        speaker_embedding = (
-            speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]]["embedding"]
-            if c.use_external_speaker_embedding_file and c.use_speaker_embedding
-            else None
-        )
-        style_wav = c.get("gst_style_input")
-        if style_wav is None and c.use_gst:
+        speaker_id = 0 if config.use_speaker_embedding else None
+        speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] if config.use_external_speaker_embedding_file and config.use_speaker_embedding else None
+        style_wav = config.get("gst_style_input")
+        if style_wav is None and config.use_gst:
             # initialize GST with a zero dict.
             style_wav = {}
-            print("WARNING: You didn't provide a gst style wav, for this reason we use a zero tensor!")
-            for i in range(c.gst['gst_num_style_tokens']):
+            print("WARNING: You didn't provide a gst style wav, for this reason we use a zero tensor!")
+            for i in range(config.gst['gst_num_style_tokens']):
                 style_wav[str(i)] = 0
-        style_wav = c.get("gst_style_input", style_wav)
+        style_wav = config.get("gst_style_input", style_wav)
         for idx, test_sentence in enumerate(test_sentences):
             try:
                 wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
                     model,
                     test_sentence,
-                    c,
+                    config,
                     use_cuda,
                     ap,
                     speaker_id=speaker_id,
                     speaker_embedding=speaker_embedding,
                     style_wav=style_wav,
                     truncated=False,
-                    enable_eos_bos_chars=c.enable_eos_bos_chars,  # pylint: disable=unused-argument
+                    enable_eos_bos_chars=config.enable_eos_bos_chars,  # pylint: disable=unused-argument
                     use_griffin_lim=True,
                     do_trim_silence=False,
                 )
@@ -579,7 +546,8 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
             except:  # pylint: disable=bare-except
                 print(" !! Error creating Test Sentence -", idx)
                 traceback.print_exc()
-        tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
+        tb_logger.tb_test_audios(global_step, test_audios,
+                                 config.audio['sample_rate'])
         tb_logger.tb_test_figures(global_step, test_figures)
     return keep_avg.avg_values
@@ -588,45 +556,48 @@ def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
     global meta_data_train, meta_data_eval, speaker_mapping, symbols, phonemes, model_characters
     # Audio processor
-    ap = AudioProcessor(**c.audio)
+    ap = AudioProcessor(**config.audio.to_dict())

     # setup custom characters if set in config file.
-    if "characters" in c.keys():
-        symbols, phonemes = make_symbols(**c.characters)
+    if config.characters is not None:
+        symbols, phonemes = make_symbols(**config.characters.to_dict())

     # DISTRIBUTED
     if num_gpus > 1:
-        init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
-    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-    model_characters = phonemes if c.use_phonemes else symbols
+        init_distributed(args.rank, num_gpus, args.group_id,
+                         config.distributed["backend"], config.distributed["url"])
+    num_chars = len(phonemes) if config.use_phonemes else len(symbols)
+    model_characters = phonemes if config.use_phonemes else symbols

     # load data instances
-    meta_data_train, meta_data_eval = load_meta_data(c.datasets)
+    meta_data_train, meta_data_eval = load_meta_data(config.datasets)

     # set the portion of the data used for training
-    if "train_portion" in c.keys():
-        meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)]
-    if "eval_portion" in c.keys():
-        meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)]
+    if config.has('train_portion'):
+        meta_data_train = meta_data_train[:int(len(meta_data_train) * config.train_portion)]
+    if config.has('eval_portion'):
+        meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * config.eval_portion)]

     # parse speakers
-    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
+    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(config, args, meta_data_train, OUT_PATH)

-    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)
+    model = setup_model(num_chars, num_speakers, config, speaker_embedding_dim)

     # scalers for mixed precision training
-    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
-    scaler_st = torch.cuda.amp.GradScaler() if c.mixed_precision and c.separate_stopnet else None
+    scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None
+    scaler_st = torch.cuda.amp.GradScaler() if config.mixed_precision and config.separate_stopnet else None

-    params = set_weight_decay(model, c.wd)
-    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
-    if c.stopnet and c.separate_stopnet:
-        optimizer_st = RAdam(model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
+    params = set_weight_decay(model, config.wd)
+    optimizer = RAdam(params, lr=config.lr, weight_decay=0)
+    if config.stopnet and config.separate_stopnet:
+        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
+                             lr=config.lr,
+                             weight_decay=0)
     else:
         optimizer_st = None

     # setup criterion
-    criterion = TacotronLoss(c, stopnet_pos_weight=c.stopnet_pos_weight, ga_sigma=0.4)
+    criterion = TacotronLoss(config, stopnet_pos_weight=config.stopnet_pos_weight, ga_sigma=0.4)

     if args.restore_path:
         print(f" > Restoring from {os.path.basename(args.restore_path)}...")
         checkpoint = torch.load(args.restore_path, map_location="cpu")
@@ -635,11 +606,11 @@ def main(args):  # pylint: disable=redefined-outer-name
             model.load_state_dict(checkpoint["model"])
             # optimizer restore
             print(" > Restoring Optimizer...")
-            optimizer.load_state_dict(checkpoint["optimizer"])
-            if "scaler" in checkpoint and c.mixed_precision:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            if "scaler" in checkpoint and config.mixed_precision:
                 print(" > Restoring AMP Scaler...")
                 scaler.load_state_dict(checkpoint["scaler"])
-            if c.reinit_layers:
+            if config.reinit_layers:
                 raise RuntimeError
         except (KeyError, RuntimeError):
             print(" > Partial model initialization...")
@@ -651,9 +622,10 @@ def main(args):  # pylint: disable=redefined-outer-name
             del model_dict
         for group in optimizer.param_groups:
-            group["lr"] = c.lr
-        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
-        args.restore_step = checkpoint["step"]
+            group['lr'] = config.lr
+        print(" > Model restored from step %d" % checkpoint['step'],
+              flush=True)
+        args.restore_step = checkpoint['step']
     else:
         args.restore_step = 0
@@ -665,8 +637,10 @@ def main(args):  # pylint: disable=redefined-outer-name
     if num_gpus > 1:
         model = apply_gradient_allreduce(model)

-    if c.noam_schedule:
-        scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
+    if config.noam_schedule:
+        scheduler = NoamLR(optimizer,
+                           warmup_steps=config.warmup_steps,
+                           last_epoch=args.restore_step - 1)
     else:
         scheduler = None
@@ -680,22 +654,22 @@ def main(args):  # pylint: disable=redefined-outer-name
         print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
         best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
         print(f" > Starting with loaded last best loss {best_loss}.")
-    keep_all_best = c.get("keep_all_best", False)
-    keep_after = c.get("keep_after", 10000)  # void if keep_all_best False
+    keep_all_best = config.keep_all_best
+    keep_after = config.keep_after  # void if keep_all_best False

     # define data loaders
     train_loader = setup_loader(ap, model.decoder.r, is_val=False, verbose=True)
     eval_loader = setup_loader(ap, model.decoder.r, is_val=True)

     global_step = args.restore_step
-    for epoch in range(0, c.epochs):
-        c_logger.print_epoch_start(epoch, c.epochs)
+    for epoch in range(0, config.epochs):
+        c_logger.print_epoch_start(epoch, config.epochs)
         # set gradual training
-        if c.gradual_training is not None:
-            r, c.batch_size = gradual_training_scheduler(global_step, c)
-            c.r = r
+        if config.gradual_training is not None:
+            r, config.batch_size = gradual_training_scheduler(global_step, config)
+            config.r = r
             model.decoder.set_r(r)
-            if c.bidirectional_decoder:
+            if config.bidirectional_decoder:
                 model.decoder_backward.set_r(r)
             train_loader.dataset.outputs_per_step = r
             eval_loader.dataset.outputs_per_step = r
@@ -719,9 +693,9 @@ def main(args):  # pylint: disable=redefined-outer-name
         # eval one epoch
         eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
-        target_loss = train_avg_loss_dict["avg_postnet_loss"]
-        if c.run_eval:
-            target_loss = eval_avg_loss_dict["avg_postnet_loss"]
+        target_loss = train_avg_loss_dict['avg_postnet_loss']
+        if config.run_eval:
+            target_loss = eval_avg_loss_dict['avg_postnet_loss']
         best_loss = save_best_model(
             target_loss,
             best_loss,
@@ -729,31 +703,26 @@ def main(args):  # pylint: disable=redefined-outer-name
             optimizer,
             global_step,
             epoch,
-            c.r,
+            config.r,
             OUT_PATH,
             model_characters,
             keep_all_best=keep_all_best,
             keep_after=keep_after,
-            scaler=scaler.state_dict() if c.mixed_precision else None,
+            scaler=scaler.state_dict() if config.mixed_precision else None
         )

-if __name__ == "__main__":
-    args = parse_arguments(sys.argv)
-    c = TacotronConfig()
-    args = c.init_argparse(args)
-    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
-        args, c, model_type='tacotron')
+if __name__ == '__main__':
+    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)

     try:
         main(args)
     except KeyboardInterrupt:
-        remove_experiment_folder(OUT_PATH)
+        # remove_experiment_folder(OUT_PATH)
         try:
             sys.exit(0)
         except SystemExit:
             os._exit(0)  # pylint: disable=protected-access
     except Exception:  # pylint: disable=broad-except
-        remove_experiment_folder(OUT_PATH)
+        # remove_experiment_folder(OUT_PATH)
         traceback.print_exc()
         sys.exit(1)
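
One detail worth calling out in format_data above: stop targets arrive as one flag per spectrogram frame, but the decoder emits r frames per step, so the flags are grouped into blocks of r and each block is collapsed to a single stop token. A small standalone sketch of that grouping (the values are invented for illustration):

import torch

r = 2  # decoder produces r frames per step
# one utterance, 6 frames; the stop flag fires on the last two frames
stop_targets = torch.tensor([[0., 0., 0., 0., 1., 1.]])  # [B, T]

# group T frames into T // r blocks of r flags each -> [B, T // r, r]
stop_targets = stop_targets.view(1, stop_targets.size(1) // r, -1)
# a block is a stop block if any of its r frames carries a stop flag
stop_targets = (stop_targets.sum(2) > 0.0).float()
print(stop_targets)  # tensor([[0., 0., 1.]]) -> one target per decoder step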

TTS/tts/datasets/preprocess.py

@@ -37,8 +37,8 @@ def load_meta_data(datasets, eval_split=True):
             meta_data_eval_all += meta_data_eval
         meta_data_train_all += meta_data_train
         # load attention masks for duration predictor training
-        if "meta_file_attn_mask" in dataset and dataset["meta_file_attn_mask"] is not None:
-            meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
+        if dataset.meta_file_attn_mask is not None:
+            meta_data = dict(load_attention_mask_meta_data(dataset['meta_file_attn_mask']))
             for idx, ins in enumerate(meta_data_train_all):
                 attn_file = meta_data[ins[1]].strip()
                 meta_data_train_all[idx].append(attn_file)

TTS/tts/utils/generic_utils.py

@@ -38,7 +38,7 @@ def sequence_mask(sequence_length, max_len=None):

 def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
     print(" > Using model: {}".format(c.model))
-    find_module("TTS.tts.models", c.model.lower())
+    MyModel = find_module("TTS.tts.models", c.model.lower())
     if c.model.lower() in "tacotron":
         model = MyModel(
             num_chars=num_chars + getattr(c, "add_blank", False),
@@ -76,11 +76,11 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             r=c.r,
             postnet_output_dim=c.audio["num_mels"],
             decoder_output_dim=c.audio["num_mels"],
-            gst=c.use_gst,
-            gst_embedding_dim=c.gst["gst_embedding_dim"],
-            gst_num_heads=c.gst["gst_num_heads"],
-            gst_style_tokens=c.gst["gst_style_tokens"],
-            gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
+            gst=c.gst is not None,
+            gst_embedding_dim=None if c.gst is None else c.gst['gst_embedding_dim'],
+            gst_num_heads=None if c.gst is None else c.gst['gst_num_heads'],
+            gst_num_style_tokens=None if c.gst is None else c.gst['gst_num_style_tokens'],
+            gst_use_speaker_embedding=None if c.gst is None else c.gst['gst_use_speaker_embedding'],
             attn_type=c.attention_type,
             attn_win=c.windowing,
             attn_norm=c.attention_norm,
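
setup_model now keeps the class returned by find_module instead of discarding the call's result. The helper lives in TTS.utils.generic_utils; the sketch below is an illustrative reimplementation of what such a resolver does, assuming each module under TTS.tts.models exposes a model class whose lowercased name equals the module name — it is not the repo's exact code:

import importlib

def find_module(module_path: str, module_name: str) -> type:
    """Import `module_path.module_name` and return the class whose
    lowercased name matches `module_name` (illustrative sketch)."""
    module = importlib.import_module(module_path + "." + module_name)
    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        if isinstance(attr, type) and attr_name.lower() == module_name:
            return attr
    raise ImportError(f"No class named '{module_name}' in {module.__name__}")

# MyModel = find_module("TTS.tts.models", "tacotron")  # -> the Tacotron class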

TTS/utils/arguments.py

@@ -6,16 +6,17 @@ import argparse
 import glob
 import json
 import os
+import sys
 import re

 from TTS.tts.utils.text.symbols import parse_symbols
 from TTS.utils.console_logger import ConsoleLogger
 from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
-from TTS.utils.io import copy_model_files
+from TTS.utils.io import copy_model_files, load_config
 from TTS.utils.tensorboard_logger import TensorboardLogger


-def parse_arguments(argv):
+def init_arguments(argv):
     """Parse command line arguments of training scripts.

     Args:
@@ -45,16 +46,26 @@ def parse_arguments(argv):
             "Best model file to be used for extracting best loss."
             "If not specified, the latest best model in continue path is used"
         ),
-        default="",
-    )
+        default="")
+    parser.add_argument("--config_path",
+                        type=str,
+                        help="Path to config file for training.",
+                        required="--continue_path" not in argv)
+    parser.add_argument("--debug",
+                        type=bool,
+                        default=False,
+                        help="Do not verify commit integrity to run training.")
     parser.add_argument(
-        "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
-    )
-    parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
-    parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
-    parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
-    return parser.parse_args()
+        "--rank",
+        type=int,
+        default=0,
+        help="DISTRIBUTED: process rank for distributed training.")
+    parser.add_argument("--group_id",
+                        type=str,
+                        default="",
+                        help="DISTRIBUTED: process group id.")
+    return parser


 def get_last_checkpoint(path):
@@ -115,7 +126,7 @@ def get_last_checkpoint(path):
     return last_models["checkpoint"], last_models["best_model"]


-def process_args(args, config, tb_prefix):
+def process_args(args):
     """Process parsed command line arguments.

     Args:
@@ -130,21 +141,27 @@ def process_args(args):
         tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
             the TensorBoard logging.
     """
+    if isinstance(args, tuple):
+        args, coqpit_overrides = args
     if args.continue_path:
         # continue a previous training from its output folder
-        args.output_path = args.continue_path
+        experiment_path = args.continue_path
         args.config_path = os.path.join(args.continue_path, "config.json")
         args.restore_path, best_model = get_last_checkpoint(args.continue_path)
         if not args.best_path:
             args.best_path = best_model

     # setup output paths and read configs
-    config.load_json(args.config_path)
+    config = load_config(args.config_path)
+    # override values from command-line args
+    config.parse_args(coqpit_overrides)
     if config.mixed_precision:
         print(" > Mixed precision mode is ON")
     if not os.path.exists(config.output_path):
-        out_path = create_experiment_folder(config.output_path, config.run_name,
-                                            args.debug)
-    audio_path = os.path.join(out_path, "test_audios")
+        experiment_path = create_experiment_folder(config.output_path,
+                                                   config.run_name, args.debug)
+    else:
+        experiment_path = config.output_path
+    audio_path = os.path.join(experiment_path, "test_audios")

     # setup rank 0 process in distributed training
     if args.rank == 0:
         os.makedirs(audio_path, exist_ok=True)
@@ -157,13 +174,22 @@ def process_args(args):
         # compatibility.
         if config.has('characters_config'):
             used_characters = parse_symbols()
-            new_fields["characters"] = used_characters
-        copy_model_files(c, args.config_path, out_path, new_fields)
+            new_fields['characters'] = used_characters
+        copy_model_files(config, args.config_path, experiment_path, new_fields)
         os.chmod(audio_path, 0o775)
-        os.chmod(out_path, 0o775)
-    log_path = out_path
-    tb_logger = TensorboardLogger(log_path, model_name=tb_prefix)
+        os.chmod(experiment_path, 0o775)
+    tb_logger = TensorboardLogger(experiment_path,
+                                  model_name=config.model)
     # write model desc to tensorboard
-    tb_logger.tb_add_text("model-description", config["run_description"], 0)
+    tb_logger.tb_add_text("model-description", config["run_description"],
+                          0)
     c_logger = ConsoleLogger()
-    return c, out_path, audio_path, c_logger, tb_logger
+    return config, experiment_path, audio_path, c_logger, tb_logger
+
+
+def init_training(argv):
+    """Initialization of a training run."""
+    parser = init_arguments(argv)
+    args = parser.parse_known_args()
+    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args)
+    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger
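
With this, init_training becomes the single entry point for training scripts: init_arguments builds the parser, parse_known_args returns an (args, unknown_flags) tuple, and process_args unpacks that tuple and feeds the unknown flags to config.parse_args as overrides. A minimal script skeleton matching the train-script change above:

import sys

from TTS.utils.arguments import init_training

if __name__ == '__main__':
    # known flags (--config_path, --rank, ...) land in `args`;
    # anything else (e.g. --batch_size 16) is applied onto `config`
    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
    print(config.model, OUT_PATH)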

TTS/utils/io.py

@@ -3,9 +3,11 @@ import os
 import pickle as pickle_tts
 import re
 from shutil import copyfile
-from TTS.utils.generic_utils import find_module

 import yaml

+from TTS.utils.generic_utils import find_module
+from .generic_utils import find_module
+

 class RenamingUnpickler(pickle_tts.Unpickler):
@@ -35,26 +37,25 @@ def read_json_with_comments(json_path):
         data = json.loads(input_str)
     return data


-def load_config(config_path: str) -> AttrDict:
-    """DEPRECATED: Load config files and discard comments
-
-    Args:
-        config_path (str): path to config file.
-    """
-    config_dict = AttrDict()
+def load_config(config_path: str):
+    config_dict = {}
     ext = os.path.splitext(config_path)[1]
     if ext in (".yml", ".yaml"):
         with open(config_path, "r", encoding="utf-8") as f:
             data = yaml.safe_load(f)
-    else:
+    elif ext == '.json':
         with open(config_path, "r", encoding="utf-8") as f:
             input_str = f.read()
             data = json.loads(input_str)
+    else:
+        raise TypeError(f' [!] Unknown config file type {ext}')
     config_dict.update(data)
-    config_class = find_module('TTS.tts.configs', config_dict.model.lower()+'_config')
+    config_class = find_module('TTS.tts.configs', config_dict['model'].lower()+'_config')
     config = config_class()
     config.from_dict(config_dict)
-    return
+    return config


 def copy_model_files(c, config_file, out_path, new_fields):
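
The rewritten load_config no longer returns a bare dict: it parses JSON or YAML, resolves the model-specific config class in TTS.tts.configs from the file's "model" field, and returns a populated instance. A usage sketch (the config file contents here are illustrative):

from TTS.utils.io import load_config

# config.json containing, e.g., {"model": "tacotron2", "batch_size": 32, ...}
# resolves to TTS.tts.configs.tacotron2_config and its config class
config = load_config("config.json")
print(type(config).__name__)   # e.g. Tacotron2Config
print(config.batch_size)       # fields populated via config.from_dict(...)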