mirror of https://github.com/coqui-ai/TTS.git
Update `TTS.bin` scripts for the new API
This commit is contained in:
parent
d7225eedb0
commit
45947acb60
|
@ -75,7 +75,7 @@ Example run:
|
||||||
# load the model
|
# load the model
|
||||||
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
||||||
# TODO: handle multi-speaker
|
# TODO: handle multi-speaker
|
||||||
model = setup_model(num_chars, num_speakers=0, c=C)
|
model = setup_model(C)
|
||||||
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
|
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
|
|
@ -77,7 +77,7 @@ def main():
|
||||||
print(f" > Avg mel spec mean: {mel_mean.mean()}")
|
print(f" > Avg mel spec mean: {mel_mean.mean()}")
|
||||||
print(f" > Avg mel spec scale: {mel_scale.mean()}")
|
print(f" > Avg mel spec scale: {mel_scale.mean()}")
|
||||||
print(f" > Avg linear spec mean: {linear_mean.mean()}")
|
print(f" > Avg linear spec mean: {linear_mean.mean()}")
|
||||||
print(f" > Avg lienar spec scale: {linear_scale.mean()}")
|
print(f" > Avg linear spec scale: {linear_scale.mean()}")
|
||||||
|
|
||||||
# set default config values for mean-var scaling
|
# set default config values for mean-var scaling
|
||||||
CONFIG.audio.stats_path = output_file_path
|
CONFIG.audio.stats_path = output_file_path
|
||||||
|
|
|
@ -31,18 +31,18 @@ c = load_config(config_path)
|
||||||
num_speakers = 0
|
num_speakers = 0
|
||||||
|
|
||||||
# init torch model
|
# init torch model
|
||||||
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
model = setup_model(c)
|
||||||
model = setup_model(num_chars, num_speakers, c)
|
|
||||||
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
|
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
|
||||||
state_dict = checkpoint["model"]
|
state_dict = checkpoint["model"]
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
# init tf model
|
# init tf model
|
||||||
|
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
||||||
model_tf = Tacotron2(
|
model_tf = Tacotron2(
|
||||||
num_chars=num_chars,
|
num_chars=num_chars,
|
||||||
num_speakers=num_speakers,
|
num_speakers=num_speakers,
|
||||||
r=model.decoder.r,
|
r=model.decoder.r,
|
||||||
postnet_output_dim=c.audio["num_mels"],
|
out_channels=c.audio["num_mels"],
|
||||||
decoder_output_dim=c.audio["num_mels"],
|
decoder_output_dim=c.audio["num_mels"],
|
||||||
attn_type=c.attention_type,
|
attn_type=c.attention_type,
|
||||||
attn_win=c.windowing,
|
attn_win=c.windowing,
|
||||||
|
|
|
@ -1,36 +1,24 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from TTS.trainer import TrainingArgs
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""
|
"""
|
||||||
Call train.py as a new process and pass command arguments
|
Call train.py as a new process and pass command arguments
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser()
|
parser = TrainingArgs().init_argparse(arg_prefix="")
|
||||||
parser.add_argument("--script", type=str, help="Target training script to distibute.")
|
parser.add_argument("--script", type=str, help="Target training script to distibute.")
|
||||||
parser.add_argument(
|
|
||||||
"--continue_path",
|
|
||||||
type=str,
|
|
||||||
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
|
||||||
default="",
|
|
||||||
required="--config_path" not in sys.argv,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
|
|
||||||
)
|
|
||||||
args, unargs = parser.parse_known_args()
|
args, unargs = parser.parse_known_args()
|
||||||
|
breakpoint()
|
||||||
|
|
||||||
num_gpus = torch.cuda.device_count()
|
num_gpus = torch.cuda.device_count()
|
||||||
group_id = time.strftime("%Y_%m_%d-%H%M%S")
|
group_id = time.strftime("%Y_%m_%d-%H%M%S")
|
||||||
|
@ -51,6 +39,7 @@ def main():
|
||||||
my_env = os.environ.copy()
|
my_env = os.environ.copy()
|
||||||
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
|
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
|
||||||
command[-1] = "--rank={}".format(i)
|
command[-1] = "--rank={}".format(i)
|
||||||
|
# prevent stdout for processes with rank != 0
|
||||||
stdout = None if i == 0 else open(os.devnull, "w")
|
stdout = None if i == 0 else open(os.devnull, "w")
|
||||||
p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with
|
p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
|
|
|
@ -14,7 +14,6 @@ from TTS.tts.datasets import load_meta_data
|
||||||
from TTS.tts.datasets.TTSDataset import TTSDataset
|
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
from TTS.tts.utils.speakers import get_speaker_manager
|
from TTS.tts.utils.speakers import get_speaker_manager
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import count_parameters
|
from TTS.utils.generic_utils import count_parameters
|
||||||
|
|
||||||
|
@ -40,9 +39,7 @@ def setup_loader(ap, r, verbose=False):
|
||||||
use_noise_augment=False,
|
use_noise_augment=False,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
speaker_id_mapping=speaker_manager.speaker_ids,
|
speaker_id_mapping=speaker_manager.speaker_ids,
|
||||||
d_vector_mapping=speaker_manager.d_vectors
|
d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None,
|
||||||
if c.use_speaker_embedding and c.use_external_speaker_embedding_file
|
|
||||||
else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if c.use_phonemes and c.compute_input_seq_cache:
|
if c.use_phonemes and c.compute_input_seq_cache:
|
||||||
|
@ -224,16 +221,10 @@ def extract_spectrograms(
|
||||||
|
|
||||||
def main(args): # pylint: disable=redefined-outer-name
|
def main(args): # pylint: disable=redefined-outer-name
|
||||||
# pylint: disable=global-variable-undefined
|
# pylint: disable=global-variable-undefined
|
||||||
global meta_data, symbols, phonemes, model_characters, speaker_manager
|
global meta_data, speaker_manager
|
||||||
|
|
||||||
# Audio processor
|
# Audio processor
|
||||||
ap = AudioProcessor(**c.audio)
|
ap = AudioProcessor(**c.audio)
|
||||||
if "characters" in c.keys() and c["characters"]:
|
|
||||||
symbols, phonemes = make_symbols(**c.characters)
|
|
||||||
|
|
||||||
# set model characters
|
|
||||||
model_characters = phonemes if c.use_phonemes else symbols
|
|
||||||
num_chars = len(model_characters)
|
|
||||||
|
|
||||||
# load data instances
|
# load data instances
|
||||||
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
||||||
|
@ -245,7 +236,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
speaker_manager = get_speaker_manager(c, args, meta_data_train)
|
speaker_manager = get_speaker_manager(c, args, meta_data_train)
|
||||||
|
|
||||||
# setup model
|
# setup model
|
||||||
model = setup_model(num_chars, speaker_manager.num_speakers, c, d_vector_dim=speaker_manager.d_vector_dim)
|
model = setup_model(c)
|
||||||
|
|
||||||
# restore model
|
# restore model
|
||||||
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
||||||
|
@ -283,5 +274,5 @@ if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
c = load_config(args.config_path)
|
c = load_config(args.config_path)
|
||||||
c.audio["do_trim_silence"] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel
|
c.audio.trim_silence = False
|
||||||
main(args)
|
main(args)
|
||||||
|
|
Loading…
Reference in New Issue