Update `TTS.bin` scripts for the new API

This commit is contained in:
Eren Gölge 2021-06-18 13:44:02 +02:00
parent aed919cf1c
commit f3ff5b1971
5 changed files with 14 additions and 33 deletions

View File

@@ -75,7 +75,7 @@ Example run:
# load the model
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
# TODO: handle multi-speaker
model = setup_model(num_chars, num_speakers=0, c=C)
model = setup_model(C)
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
model.eval()

View File

@@ -77,7 +77,7 @@ def main():
print(f" > Avg mel spec mean: {mel_mean.mean()}")
print(f" > Avg mel spec scale: {mel_scale.mean()}")
print(f" > Avg linear spec mean: {linear_mean.mean()}")
print(f" > Avg lienar spec scale: {linear_scale.mean()}")
print(f" > Avg linear spec scale: {linear_scale.mean()}")
# set default config values for mean-var scaling
CONFIG.audio.stats_path = output_file_path

View File

@@ -31,18 +31,18 @@ c = load_config(config_path)
num_speakers = 0
# init torch model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c)
model = setup_model(c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
# init tf model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model_tf = Tacotron2(
num_chars=num_chars,
num_speakers=num_speakers,
r=model.decoder.r,
postnet_output_dim=c.audio["num_mels"],
out_channels=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
attn_type=c.attention_type,
attn_win=c.windowing,

View File

@@ -1,36 +1,24 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import pathlib
import subprocess
import sys
import time
import torch
from TTS.trainer import TrainingArgs
def main():
"""
Call train.py as a new process and pass command arguments
"""
parser = argparse.ArgumentParser()
parser = TrainingArgs().init_argparse(arg_prefix="")
parser.add_argument("--script", type=str, help="Target training script to distibute.")
parser.add_argument(
"--continue_path",
type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default="",
required="--config_path" not in sys.argv,
)
parser.add_argument(
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
)
parser.add_argument(
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
)
args, unargs = parser.parse_known_args()
breakpoint()
num_gpus = torch.cuda.device_count()
group_id = time.strftime("%Y_%m_%d-%H%M%S")
@@ -51,6 +39,7 @@ def main():
my_env = os.environ.copy()
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
command[-1] = "--rank={}".format(i)
# prevent stdout for processes with rank != 0
stdout = None if i == 0 else open(os.devnull, "w")
p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with
processes.append(p)

View File

@@ -14,7 +14,6 @@ from TTS.tts.datasets import load_meta_data
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
@@ -40,9 +39,7 @@ def setup_loader(ap, r, verbose=False):
use_noise_augment=False,
verbose=verbose,
speaker_id_mapping=speaker_manager.speaker_ids,
d_vector_mapping=speaker_manager.d_vectors
if c.use_speaker_embedding and c.use_external_speaker_embedding_file
else None,
d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None,
)
if c.use_phonemes and c.compute_input_seq_cache:
@@ -224,16 +221,10 @@ def extract_spectrograms(
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data, symbols, phonemes, model_characters, speaker_manager
global meta_data, speaker_manager
# Audio processor
ap = AudioProcessor(**c.audio)
if "characters" in c.keys() and c["characters"]:
symbols, phonemes = make_symbols(**c.characters)
# set model characters
model_characters = phonemes if c.use_phonemes else symbols
num_chars = len(model_characters)
# load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
@@ -245,7 +236,7 @@ def main(args): # pylint: disable=redefined-outer-name
speaker_manager = get_speaker_manager(c, args, meta_data_train)
# setup model
model = setup_model(num_chars, speaker_manager.num_speakers, c, d_vector_dim=speaker_manager.d_vector_dim)
model = setup_model(c)
# restore model
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
@@ -283,4 +274,5 @@ if __name__ == "__main__":
args = parser.parse_args()
c = load_config(args.config_path)
c.audio.trim_silence = False
main(args)