From f3ff5b19712eb88d09a03d0a896e8f9976e770f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 18 Jun 2021 13:44:02 +0200 Subject: [PATCH] Update `TTS.bin` scripts for the new API --- TTS/bin/compute_attention_masks.py | 2 +- TTS/bin/compute_statistics.py | 2 +- TTS/bin/convert_tacotron2_torch_to_tf.py | 6 +++--- TTS/bin/distribute.py | 21 +++++---------------- TTS/bin/extract_tts_spectrograms.py | 16 ++++------------ 5 files changed, 14 insertions(+), 33 deletions(-) diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index eb708040..35721f59 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -75,7 +75,7 @@ Example run: # load the model num_chars = len(phonemes) if C.use_phonemes else len(symbols) # TODO: handle multi-speaker - model = setup_model(num_chars, num_speakers=0, c=C) + model = setup_model(C) model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda) model.eval() diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index 25e3fce5..6179dafc 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -77,7 +77,7 @@ def main(): print(f" > Avg mel spec mean: {mel_mean.mean()}") print(f" > Avg mel spec scale: {mel_scale.mean()}") print(f" > Avg linear spec mean: {linear_mean.mean()}") - print(f" > Avg lienar spec scale: {linear_scale.mean()}") + print(f" > Avg linear spec scale: {linear_scale.mean()}") # set default config values for mean-var scaling CONFIG.audio.stats_path = output_file_path diff --git a/TTS/bin/convert_tacotron2_torch_to_tf.py b/TTS/bin/convert_tacotron2_torch_to_tf.py index 119529ae..a6fb5d9b 100644 --- a/TTS/bin/convert_tacotron2_torch_to_tf.py +++ b/TTS/bin/convert_tacotron2_torch_to_tf.py @@ -31,18 +31,18 @@ c = load_config(config_path) num_speakers = 0 # init torch model -num_chars = len(phonemes) if c.use_phonemes else len(symbols) -model = setup_model(num_chars, num_speakers, c) +model = setup_model(c) checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) state_dict = checkpoint["model"] model.load_state_dict(state_dict) # init tf model +num_chars = len(phonemes) if c.use_phonemes else len(symbols) model_tf = Tacotron2( num_chars=num_chars, num_speakers=num_speakers, r=model.decoder.r, - postnet_output_dim=c.audio["num_mels"], + out_channels=c.audio["num_mels"], decoder_output_dim=c.audio["num_mels"], attn_type=c.attention_type, attn_win=c.windowing, diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py index 20d4bb20..873ddb1f 100644 --- a/TTS/bin/distribute.py +++ b/TTS/bin/distribute.py @@ -1,36 +1,24 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import argparse import os import pathlib import subprocess -import sys import time import torch +from TTS.trainer import TrainingArgs + def main(): """ Call train.py as a new process and pass command arguments """ - parser = argparse.ArgumentParser() + parser = TrainingArgs().init_argparse(arg_prefix="") parser.add_argument("--script", type=str, help="Target training script to distibute.") - parser.add_argument( - "--continue_path", - type=str, - help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', - default="", - required="--config_path" not in sys.argv, - ) - parser.add_argument( - "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default="" - ) - parser.add_argument( - "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv - ) args, unargs = parser.parse_known_args() + breakpoint() num_gpus = torch.cuda.device_count() group_id = time.strftime("%Y_%m_%d-%H%M%S") @@ -51,6 +39,7 @@ def main(): my_env = os.environ.copy() my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i) command[-1] = "--rank={}".format(i) + # prevent stdout for processes with rank != 0 stdout = None if i == 0 else open(os.devnull, "w") p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with processes.append(p) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index c5ba1b2a..11cdfe31 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -14,7 +14,6 @@ from TTS.tts.datasets import load_meta_data from TTS.tts.datasets.TTSDataset import TTSDataset from TTS.tts.models import setup_model from TTS.tts.utils.speakers import get_speaker_manager -from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters @@ -40,9 +39,7 @@ def setup_loader(ap, r, verbose=False): use_noise_augment=False, verbose=verbose, speaker_id_mapping=speaker_manager.speaker_ids, - d_vector_mapping=speaker_manager.d_vectors - if c.use_speaker_embedding and c.use_external_speaker_embedding_file - else None, + d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None, ) if c.use_phonemes and c.compute_input_seq_cache: @@ -224,16 +221,10 @@ def extract_spectrograms( def main(args): # pylint: disable=redefined-outer-name # pylint: disable=global-variable-undefined - global meta_data, symbols, phonemes, model_characters, speaker_manager + global meta_data, speaker_manager # Audio processor ap = AudioProcessor(**c.audio) - if "characters" in c.keys() and c["characters"]: - symbols, phonemes = make_symbols(**c.characters) - - # set model characters - model_characters = phonemes if c.use_phonemes else symbols - num_chars = len(model_characters) # load data instances meta_data_train, meta_data_eval = load_meta_data(c.datasets) @@ -245,7 +236,7 @@ def main(args): # pylint: disable=redefined-outer-name speaker_manager = get_speaker_manager(c, args, meta_data_train) # setup model - model = setup_model(num_chars, speaker_manager.num_speakers, c, d_vector_dim=speaker_manager.d_vector_dim) + model = setup_model(c) # restore model checkpoint = torch.load(args.checkpoint_path, map_location="cpu") @@ -283,4 +274,5 @@ if __name__ == "__main__": args = parser.parse_args() c = load_config(args.config_path) + c.audio.trim_silence = False main(args)