Update `TTS.bin` scripts for the new API

This commit is contained in:
Eren Gölge 2021-06-18 13:44:02 +02:00
parent d7225eedb0
commit 45947acb60
5 changed files with 14 additions and 34 deletions

View File

@@ -75,7 +75,7 @@ Example run:
# load the model # load the model
num_chars = len(phonemes) if C.use_phonemes else len(symbols) num_chars = len(phonemes) if C.use_phonemes else len(symbols)
# TODO: handle multi-speaker # TODO: handle multi-speaker
model = setup_model(num_chars, num_speakers=0, c=C) model = setup_model(C)
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda) model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
model.eval() model.eval()

View File

@@ -77,7 +77,7 @@ def main():
print(f" > Avg mel spec mean: {mel_mean.mean()}") print(f" > Avg mel spec mean: {mel_mean.mean()}")
print(f" > Avg mel spec scale: {mel_scale.mean()}") print(f" > Avg mel spec scale: {mel_scale.mean()}")
print(f" > Avg linear spec mean: {linear_mean.mean()}") print(f" > Avg linear spec mean: {linear_mean.mean()}")
print(f" > Avg lienar spec scale: {linear_scale.mean()}") print(f" > Avg linear spec scale: {linear_scale.mean()}")
# set default config values for mean-var scaling # set default config values for mean-var scaling
CONFIG.audio.stats_path = output_file_path CONFIG.audio.stats_path = output_file_path

View File

@@ -31,18 +31,18 @@ c = load_config(config_path)
num_speakers = 0 num_speakers = 0
# init torch model # init torch model
num_chars = len(phonemes) if c.use_phonemes else len(symbols) model = setup_model(c)
model = setup_model(num_chars, num_speakers, c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"] state_dict = checkpoint["model"]
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
# init tf model # init tf model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model_tf = Tacotron2( model_tf = Tacotron2(
num_chars=num_chars, num_chars=num_chars,
num_speakers=num_speakers, num_speakers=num_speakers,
r=model.decoder.r, r=model.decoder.r,
postnet_output_dim=c.audio["num_mels"], out_channels=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"], decoder_output_dim=c.audio["num_mels"],
attn_type=c.attention_type, attn_type=c.attention_type,
attn_win=c.windowing, attn_win=c.windowing,

View File

@@ -1,36 +1,24 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import argparse
import os import os
import pathlib import pathlib
import subprocess import subprocess
import sys
import time import time
import torch import torch
from TTS.trainer import TrainingArgs
def main(): def main():
""" """
Call train.py as a new process and pass command arguments Call train.py as a new process and pass command arguments
""" """
parser = argparse.ArgumentParser() parser = TrainingArgs().init_argparse(arg_prefix="")
parser.add_argument("--script", type=str, help="Target training script to distibute.") parser.add_argument("--script", type=str, help="Target training script to distibute.")
parser.add_argument(
"--continue_path",
type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default="",
required="--config_path" not in sys.argv,
)
parser.add_argument(
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
)
parser.add_argument(
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
)
args, unargs = parser.parse_known_args() args, unargs = parser.parse_known_args()
breakpoint()
num_gpus = torch.cuda.device_count() num_gpus = torch.cuda.device_count()
group_id = time.strftime("%Y_%m_%d-%H%M%S") group_id = time.strftime("%Y_%m_%d-%H%M%S")
@@ -51,6 +39,7 @@ def main():
my_env = os.environ.copy() my_env = os.environ.copy()
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i) my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
command[-1] = "--rank={}".format(i) command[-1] = "--rank={}".format(i)
# prevent stdout for processes with rank != 0
stdout = None if i == 0 else open(os.devnull, "w") stdout = None if i == 0 else open(os.devnull, "w")
p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with
processes.append(p) processes.append(p)

View File

@@ -14,7 +14,6 @@ from TTS.tts.datasets import load_meta_data
from TTS.tts.datasets.TTSDataset import TTSDataset from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters from TTS.utils.generic_utils import count_parameters
@@ -40,9 +39,7 @@ def setup_loader(ap, r, verbose=False):
use_noise_augment=False, use_noise_augment=False,
verbose=verbose, verbose=verbose,
speaker_id_mapping=speaker_manager.speaker_ids, speaker_id_mapping=speaker_manager.speaker_ids,
d_vector_mapping=speaker_manager.d_vectors d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None,
if c.use_speaker_embedding and c.use_external_speaker_embedding_file
else None,
) )
if c.use_phonemes and c.compute_input_seq_cache: if c.use_phonemes and c.compute_input_seq_cache:
@@ -224,16 +221,10 @@ def extract_spectrograms(
def main(args): # pylint: disable=redefined-outer-name def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined # pylint: disable=global-variable-undefined
global meta_data, symbols, phonemes, model_characters, speaker_manager global meta_data, speaker_manager
# Audio processor # Audio processor
ap = AudioProcessor(**c.audio) ap = AudioProcessor(**c.audio)
if "characters" in c.keys() and c["characters"]:
symbols, phonemes = make_symbols(**c.characters)
# set model characters
model_characters = phonemes if c.use_phonemes else symbols
num_chars = len(model_characters)
# load data instances # load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets) meta_data_train, meta_data_eval = load_meta_data(c.datasets)
@@ -245,7 +236,7 @@ def main(args): # pylint: disable=redefined-outer-name
speaker_manager = get_speaker_manager(c, args, meta_data_train) speaker_manager = get_speaker_manager(c, args, meta_data_train)
# setup model # setup model
model = setup_model(num_chars, speaker_manager.num_speakers, c, d_vector_dim=speaker_manager.d_vector_dim) model = setup_model(c)
# restore model # restore model
checkpoint = torch.load(args.checkpoint_path, map_location="cpu") checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
@@ -283,5 +274,5 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
c = load_config(args.config_path) c = load_config(args.config_path)
c.audio["do_trim_silence"] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel c.audio.trim_silence = False
main(args) main(args)