mirror of https://github.com/coqui-ai/TTS.git

Update `TTS.bin` scripts for the new API

parent aed919cf1c
commit f3ff5b1971
@@ -75,7 +75,7 @@ Example run:
     # load the model
     num_chars = len(phonemes) if C.use_phonemes else len(symbols)
     # TODO: handle multi-speaker
-    model = setup_model(num_chars, num_speakers=0, c=C)
+    model = setup_model(C)
     model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
     model.eval()

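Note: under the new API, setup_model reads the character set and speaker count from the config instead of taking them as arguments. A minimal sketch of the new call style, assuming a Tacotron2Config with default fields (illustrative, not this script's exact setup):

    from TTS.tts.configs import Tacotron2Config
    from TTS.tts.models import setup_model

    C = Tacotron2Config()   # config now carries audio, character, and speaker settings
    model = setup_model(C)  # num_chars / num_speakers are resolved from C internally
    model.eval()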
@@ -77,7 +77,7 @@ def main():
     print(f" > Avg mel spec mean: {mel_mean.mean()}")
     print(f" > Avg mel spec scale: {mel_scale.mean()}")
     print(f" > Avg linear spec mean: {linear_mean.mean()}")
-    print(f" > Avg lienar spec scale: {linear_scale.mean()}")
+    print(f" > Avg linear spec scale: {linear_scale.mean()}")

     # set default config values for mean-var scaling
     CONFIG.audio.stats_path = output_file_path
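For context: these statistics feed mean-variance scaling, i.e. each mel/linear bin is normalized to zero mean and unit variance. A standalone sketch of the transform on dummy data (not the script's stats file format):

    import numpy as np

    mel_spec = np.random.rand(80, 120).astype(np.float32)  # (num_mels, frames)
    mel_mean = mel_spec.mean(axis=1, keepdims=True)        # per-bin mean
    mel_scale = mel_spec.std(axis=1, keepdims=True)        # per-bin std ("scale")

    mel_norm = (mel_spec - mel_mean) / mel_scale           # zero mean, unit variance
    mel_restored = mel_norm * mel_scale + mel_mean         # inverse transform
    assert np.allclose(mel_restored, mel_spec, atol=1e-4)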
@@ -31,18 +31,18 @@ c = load_config(config_path)
 num_speakers = 0

 # init torch model
-num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-model = setup_model(num_chars, num_speakers, c)
+model = setup_model(c)
 checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
 state_dict = checkpoint["model"]
 model.load_state_dict(state_dict)

 # init tf model
+num_chars = len(phonemes) if c.use_phonemes else len(symbols)
 model_tf = Tacotron2(
     num_chars=num_chars,
     num_speakers=num_speakers,
     r=model.decoder.r,
-    postnet_output_dim=c.audio["num_mels"],
+    out_channels=c.audio["num_mels"],
     decoder_output_dim=c.audio["num_mels"],
     attn_type=c.attention_type,
     attn_win=c.windowing,
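Side note: the surrounding conversion script ports torch weights into the TF model. A hedged sketch of the core copy step (the real script matches layers by name and handles several kernel layouts):

    import numpy as np
    import tensorflow as tf
    import torch

    torch_kernel = torch.randn(80, 512)                       # stand-in linear weight
    tf_kernel = tf.Variable(np.zeros((512, 80), np.float32))  # TF dense kernels use the transposed layout
    tf_kernel.assign(torch_kernel.detach().numpy().T)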
@@ -1,36 +1,24 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

-import argparse
 import os
 import pathlib
 import subprocess
-import sys
 import time

 import torch

+from TTS.trainer import TrainingArgs
+

 def main():
     """
     Call train.py as a new process and pass command arguments
     """
-    parser = argparse.ArgumentParser()
+    parser = TrainingArgs().init_argparse(arg_prefix="")
     parser.add_argument("--script", type=str, help="Target training script to distibute.")
-    parser.add_argument(
-        "--continue_path",
-        type=str,
-        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
-        default="",
-        required="--config_path" not in sys.argv,
-    )
-    parser.add_argument(
-        "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
-    )
-    parser.add_argument(
-        "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
-    )
     args, unargs = parser.parse_known_args()
+    breakpoint()

     num_gpus = torch.cuda.device_count()
     group_id = time.strftime("%Y_%m_%d-%H%M%S")
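Note: parse_known_args is what lets the launcher stay thin after this change — any flag it does not define lands in unargs and can be forwarded verbatim to the spawned training script. A standalone illustration (the --coqpit.lr override is hypothetical):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--script", type=str)
    args, unargs = parser.parse_known_args(["--script", "train.py", "--coqpit.lr", "1e-4"])
    print(args.script)  # train.py
    print(unargs)       # ['--coqpit.lr', '1e-4'] -> forwarded to each worker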
@@ -51,6 +39,7 @@ def main():
         my_env = os.environ.copy()
         my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
         command[-1] = "--rank={}".format(i)
+        # prevent stdout for processes with rank != 0
         stdout = None if i == 0 else open(os.devnull, "w")
         p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env)  # pylint: disable=consider-using-with
         processes.append(p)
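Small design note: open(os.devnull, "w") is never closed in this loop; since Python 3.3, subprocess.DEVNULL gives the same muting without a dangling handle. Sketch of the alternative:

    import subprocess

    rank = 1  # stand-in for the loop variable i
    stdout = None if rank == 0 else subprocess.DEVNULL  # no file handle to leak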
@@ -14,7 +14,6 @@ from TTS.tts.datasets import load_meta_data
 from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.models import setup_model
 from TTS.tts.utils.speakers import get_speaker_manager
-from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters

@@ -40,9 +39,7 @@ def setup_loader(ap, r, verbose=False):
         use_noise_augment=False,
         verbose=verbose,
         speaker_id_mapping=speaker_manager.speaker_ids,
-        d_vector_mapping=speaker_manager.d_vectors
-        if c.use_speaker_embedding and c.use_external_speaker_embedding_file
-        else None,
+        d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None,
     )

     if c.use_phonemes and c.compute_input_seq_cache:
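For reference: the renamed use_d_vector_file flag gates the same per-utterance d-vector mapping as before. The layout below reflects the speakers-file format as I understand it and is purely illustrative:

    # assumed structure of speaker_manager.d_vectors (one entry per clip)
    d_vectors = {
        "speaker_1/utt_001.wav": {"name": "speaker_1", "embedding": [0.1] * 256},
        "speaker_2/utt_042.wav": {"name": "speaker_2", "embedding": [0.2] * 256},
    }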
@@ -224,16 +221,10 @@ def extract_spectrograms(

 def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data, symbols, phonemes, model_characters, speaker_manager
+    global meta_data, speaker_manager

     # Audio processor
     ap = AudioProcessor(**c.audio)
-    if "characters" in c.keys() and c["characters"]:
-        symbols, phonemes = make_symbols(**c.characters)
-
-    # set model characters
-    model_characters = phonemes if c.use_phonemes else symbols
-    num_chars = len(model_characters)

     # load data instances
     meta_data_train, meta_data_eval = load_meta_data(c.datasets)
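Note: AudioProcessor is constructed directly from the audio section of the config. A minimal usage sketch, assuming standard config keys (paths are hypothetical; the load_config import path may differ by version):

    from TTS.config import load_config
    from TTS.utils.audio import AudioProcessor

    c = load_config("config.json")              # hypothetical path
    ap = AudioProcessor(**c.audio)
    wav = ap.load_wav("tests/data/sample.wav")  # hypothetical file
    mel = ap.melspectrogram(wav)                # (num_mels, frames)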
@@ -245,7 +236,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     speaker_manager = get_speaker_manager(c, args, meta_data_train)

     # setup model
-    model = setup_model(num_chars, speaker_manager.num_speakers, c, d_vector_dim=speaker_manager.d_vector_dim)
+    model = setup_model(c)

     # restore model
     checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
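The restore step assumes the usual TTS checkpoint layout with the weights stored under the "model" key. A minimal sketch, continuing from model = setup_model(c) above:

    import torch

    checkpoint_path = "best_model.pth.tar"      # hypothetical path
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])  # weights live under the "model" key
    model.eval()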
@@ -283,4 +274,5 @@ if __name__ == "__main__":
     args = parser.parse_args()

     c = load_config(args.config_path)
+    c.audio.trim_silence = False
     main(args)
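Forcing trim_silence off looks deliberate for extraction: trimming changes the frame count, so the saved spectrograms would no longer align with the untrimmed audio. A quick back-of-envelope check of that alignment (hop_length value is illustrative):

    import numpy as np

    hop_length = 256   # illustrative hop size
    wav_len = 110_250  # samples in an untrimmed clip
    n_frames = int(np.ceil(wav_len / hop_length))
    assert abs(n_frames * hop_length - wav_len) < hop_length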