mirror of https://github.com/coqui-ai/TTS.git
styling
parent 8b1014d188
commit 12722501bb
@@ -1 +1 @@
-__version__ = '0.0.13.2'
+__version__ = "0.0.13.2"
@@ -1,25 +1,26 @@
 #!/usr/bin/env python3
 """Extract Mel spectrograms with teacher forcing."""

-import os
 import argparse
+import os

 import numpy as np
-from tqdm import tqdm
 import torch

 from torch.utils.data import DataLoader
+from tqdm import tqdm

+from TTS.config import load_config
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.speakers import parse_speakers
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
-from TTS.config import load_config
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters

 use_cuda = torch.cuda.is_available()


 def setup_loader(ap, r, verbose=False):
     dataset = MyDataset(
         r,
@@ -38,9 +39,7 @@ def setup_loader(ap, r, verbose=False):
         enable_eos_bos=c.enable_eos_bos_chars,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_mapping=speaker_mapping
-        if c.use_speaker_embedding and c.use_external_speaker_embedding_file
-        else None,
+        speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None,
     )

     if c.use_phonemes and c.compute_input_seq_cache:
@@ -60,19 +59,21 @@ def setup_loader(ap, r, verbose=False):
     )
     return loader


 def set_filename(wav_path, out_path):
     wav_file = os.path.basename(wav_path)
-    file_name = wav_file.split('.')[0]
+    file_name = wav_file.split(".")[0]
     os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
     wavq_path = os.path.join(out_path, "quant", file_name)
     mel_path = os.path.join(out_path, "mel", file_name)
-    wav_gl_path = os.path.join(out_path, "wav_gl", file_name+'.wav')
-    wav_path = os.path.join(out_path, "wav", file_name+'.wav')
+    wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav")
+    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
     return file_name, wavq_path, mel_path, wav_gl_path, wav_path


 def format_data(data):
     # setup input data
     text_input = data[0]
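As an aside on the function touched above: set_filename only derives output locations and creates the directories. A minimal usage sketch, with a made-up input wav path and output directory (neither appears in this commit):

    file_name, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(
        "/data/LJSpeech/wavs/LJ001-0001.wav", "/tmp/specs"
    )
    # file_name   == "LJ001-0001"
    # wavq_path   == "/tmp/specs/quant/LJ001-0001"
    # mel_path    == "/tmp/specs/mel/LJ001-0001"
    # wav_gl_path == "/tmp/specs/wav_gl/LJ001-0001.wav"
    # wav_path    == "/tmp/specs/wav/LJ001-0001.wav"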
@@ -123,8 +124,20 @@ def format_data(data):
         item_idx,
     )


 @torch.no_grad()
-def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
+def inference(
+    model_name,
+    model,
+    ap,
+    text_input,
+    text_lengths,
+    mel_input,
+    mel_lengths,
+    attn_mask=None,
+    speaker_ids=None,
+    speaker_embeddings=None,
+):
     if model_name == "glow_tts":
         mel_input = mel_input.permute(0, 2, 1)  # B x D x T
         speaker_c = None
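The rewrapped signature above is decorated with @torch.no_grad(), which disables autograd for the entire call. A minimal standalone sketch of the same decorator pattern (model and x are placeholders, not names from this commit):

    import torch

    @torch.no_grad()
    def predict(model, x):
        # no gradient graph is recorded inside this function
        return model(x)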
@@ -140,7 +153,13 @@ def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_le

     elif "tacotron" in model_name:
         _, postnet_outputs, *_ = model(
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
+            text_input,
+            text_lengths,
+            mel_input,
+            mel_lengths,
+            speaker_ids=speaker_ids,
+            speaker_embeddings=speaker_embeddings,
+        )
         # normalize tacotron output
         if model_name == "tacotron":
             mel_specs = []
@@ -154,7 +173,10 @@ def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_le
         model_output = postnet_outputs.detach().cpu().numpy()
     return model_output

-def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"):
+def extract_spectrograms(
+    data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"
+):
     model.eval()
     export_metadata = []
     for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):
@@ -173,7 +195,18 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
             item_idx,
         ) = format_data(data)

-        model_output = inference(c.model.lower(), model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)
+        model_output = inference(
+            c.model.lower(),
+            model,
+            ap,
+            text_input,
+            text_lengths,
+            mel_input,
+            mel_lengths,
+            attn_mask,
+            speaker_ids,
+            speaker_embeddings,
+        )

         for idx in range(text_input.shape[0]):
             wav_file_path = item_idx[idx]
@@ -204,13 +237,14 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
         for data in export_metadata:
             f.write(f"{data[0]}|{data[1]+'.npy'}\n")


 def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
     global meta_data, symbols, phonemes, model_characters, speaker_mapping

     # Audio processor
     ap = AudioProcessor(**c.audio)
-    if "characters" in c.keys() and c['characters']:
+    if "characters" in c.keys() and c["characters"]:
         symbols, phonemes = make_symbols(**c.characters)

     # set model characters
@@ -242,37 +276,26 @@ def main(args):  # pylint: disable=redefined-outer-name
     r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
     own_loader = setup_loader(ap, r, verbose=True)

-    extract_spectrograms(own_loader, model, ap, args.output_path, quantized_wav=args.quantized, save_audio=args.save_audio, debug=args.debug, metada_name="metada.txt")
+    extract_spectrograms(
+        own_loader,
+        model,
+        ap,
+        args.output_path,
+        quantized_wav=args.quantized,
+        save_audio=args.save_audio,
+        debug=args.debug,
+        metada_name="metada.txt",
+    )


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--config_path',
-        type=str,
-        help='Path to config file for training.',
-        required=True)
-    parser.add_argument(
-        '--checkpoint_path',
-        type=str,
-        help='Model file to be restored.',
-        required=True)
-    parser.add_argument(
-        '--output_path',
-        type=str,
-        help='Path to save mel specs',
-        required=True)
-    parser.add_argument('--debug',
-                        default=False,
-                        action='store_true',
-                        help='Save audio files for debug')
-    parser.add_argument('--save_audio',
-                        default=False,
-                        action='store_true',
-                        help='Save audio files')
-    parser.add_argument('--quantized',
-                        action='store_true',
-                        help='Save quantized audio files')
+    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
+    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
+    parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
+    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
+    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
+    parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
     args = parser.parse_args()

     c = load_config(args.config_path)
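For reference, the reformatted flags leave the CLI itself unchanged; the tests further down in this diff invoke the script like this (the config, checkpoint, and output names there are test fixtures, used here only as placeholders):

    CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py \
        --config_path config.json \
        --checkpoint_path checkpoint.pth.tar \
        --output_path output/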
@@ -2,10 +2,10 @@
 # TODO: mixed precision training
 """Trains GAN based vocoder model."""

+import itertools
 import os
 import sys
 import time
-import itertools
 import traceback
 from inspect import signature
@@ -496,8 +496,12 @@ def main(args):  # pylint: disable=redefined-outer-name
     optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
     optimizer_disc = getattr(torch.optim, c.optimizer)

-    if c.discriminator_model == 'hifigan_discriminator':
-        optimizer_disc = optimizer_disc(itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()), lr=c.lr_disc, **c.optimizer_params)
+    if c.discriminator_model == "hifigan_discriminator":
+        optimizer_disc = optimizer_disc(
+            itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()),
+            lr=c.lr_disc,
+            **c.optimizer_params,
+        )
     else:
         optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
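The hifigan_discriminator branch above feeds the parameters of two sub-discriminators (msd and mpd) into a single optimizer via itertools.chain. A minimal self-contained sketch of that pattern; Adam and the Linear stand-ins are purely illustrative, since the real optimizer class is looked up from c.optimizer:

    import itertools

    import torch

    msd = torch.nn.Linear(4, 4)  # stand-ins for the two sub-discriminators
    mpd = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(
        itertools.chain(msd.parameters(), mpd.parameters()), lr=0.0002
    )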
@@ -1,5 +1,6 @@
 from dataclasses import asdict, dataclass
 from typing import List, Union

 from coqpit import MISSING, Coqpit, check_argument

+
@@ -151,11 +152,14 @@ class BaseDatasetConfig(Coqpit):
            Path to the file that lists the attention mask files used with models that require attention masks to
            train the duration predictor.
    """
-    name: str = ''
-    path: str = ''
-    meta_file_train: Union[str, List] = ''  # TODO: don't take ignored speakers for multi-speaker datasets over this. This is Union for SC-Glow compat.
-    meta_file_val: str = ''
-    meta_file_attn_mask: str = ''
+    name: str = ""
+    path: str = ""
+    meta_file_train: Union[
+        str, List
+    ] = ""  # TODO: don't take ignored speakers for multi-speaker datasets over this. This is Union for SC-Glow compat.
+    meta_file_val: str = ""
+    meta_file_attn_mask: str = ""

     def check_values(
         self,
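The fields reformatted above are plain dataclass defaults, so the config can be built with keyword arguments. A hypothetical instantiation, with the dataset name and paths invented for illustration:

    dataset_config = BaseDatasetConfig(
        name="ljspeech",
        path="/data/LJSpeech-1.1/",
        meta_file_train="metadata.csv",
        meta_file_val="",
    )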
@@ -59,7 +59,6 @@ class GlowTTSConfig(BaseTTSConfig):
            Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
    """

-
     model: str = "glow_tts"

     # model params
@@ -22,6 +22,7 @@ class GSTConfig(Coqpit):
        gst_num_style_tokens (int):
            Number of style token vectors. Defaults to 10.
    """
+
    gst_style_input_wav: str = None
    gst_style_input_weights: dict = None
    gst_embedding_dim: int = 256
@@ -113,4 +113,4 @@ class SpeedySpeechConfig(BaseTTSConfig):
     # overrides
     min_seq_len: int = 13
     max_seq_len: int = 200
-    r: int = 1  #DO NOT CHANGE
+    r: int = 1  # DO NOT CHANGE
@@ -90,7 +90,6 @@ class MultibandMelganConfig(BaseGANVocoderConfig):
            L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
    """

-
     model: str = "multiband_melgan"

     # Model specific params
@@ -66,6 +66,7 @@ class WavegradConfig(BaseVocoderConfig):
        lr_scheduler_params (dict):
            kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`
    """
+
    model: str = "wavegrad"
    # Model specific params
    generator_model: str = "wavegrad"
@@ -3,13 +3,9 @@ import unittest

 import torch

-from tests import get_tests_input_path
-
-from tests import get_tests_output_path, run_cli
-
-from TTS.tts.utils.generic_utils import setup_model
-
+from tests import get_tests_input_path, get_tests_output_path, run_cli
 from TTS.config import load_config
+from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.text.symbols import phonemes, symbols

 torch.manual_seed(1)
@@ -20,8 +16,8 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
     def test_GlowTTS():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
-        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
+        output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
         # create model
@@ -30,14 +26,17 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
         # save model
         torch.save({"model": model.state_dict()}, checkpoint_path)
         # run test
-        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(
+            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
+        )
         run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')

     @staticmethod
     def test_Tacotron2():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
-        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
+        output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
         # create model
@@ -46,14 +45,17 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
         # save model
         torch.save({"model": model.state_dict()}, checkpoint_path)
         # run test
-        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(
+            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
+        )
         run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')

     @staticmethod
     def test_Tacotron():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
-        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
+        output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
         # create model
@@ -62,5 +64,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
         # save model
         torch.save({"model": model.state_dict()}, checkpoint_path)
         # run test
-        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(
+            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
+        )
         run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')
@@ -130,6 +130,7 @@ class GlowTTSTrainTest(unittest.TestCase):
             )
             count += 1

+
 class GlowTTSInferenceTest(unittest.TestCase):
     @staticmethod
     def test_inference():
@@ -174,13 +175,12 @@ class GlowTTSInferenceTest(unittest.TestCase):
         print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))

         # inference encoder and decoder with MAS
-        y, *_ = model.inference_with_MAS(
-            input_dummy, input_lengths, mel_spec, mel_lengths, None
-        )
+        y, *_ = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths, None)

-        y_dec, _ = model.decoder_inference(mel_spec, mel_lengths
-        )
+        y_dec, _ = model.decoder_inference(mel_spec, mel_lengths)

-        assert (y_dec.shape == y.shape), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
+        assert (
+            y_dec.shape == y.shape
+        ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
             y.shape, y_dec.shape
         )