Eren Gölge 2021-05-15 23:48:31 +02:00
parent 8b1014d188
commit 12722501bb
11 changed files with 115 additions and 80 deletions

View File

@@ -1 +1 @@
__version__ = '0.0.13.2'
__version__ = "0.0.13.2"

View File

@@ -1,25 +1,26 @@
#!/usr/bin/env python3
"""Extract Mel spectrograms with teacher forcing."""
import os
import argparse
import os
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.config import load_config
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
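The import block above is also regrouped: standard-library modules first, then third-party packages, with the first-party TTS imports in their own group last, alphabetized within each group. A hedged sketch of that ordering via isort's Python API (isort 5), assuming isort is the tool behind it:

```python
# Hedged sketch: isort separates stdlib from third-party imports and sorts
# each group alphabetically, matching the reordering above.
import isort

print(isort.code("from tqdm import tqdm\nimport torch\nimport argparse\nimport os\n"))
# Expected grouping: argparse and os first, then torch and tqdm.
```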
use_cuda = torch.cuda.is_available()
def setup_loader(ap, r, verbose=False):
dataset = MyDataset(
r,
@@ -38,9 +39,7 @@ def setup_loader(ap, r, verbose=False):
enable_eos_bos=c.enable_eos_bos_chars,
use_noise_augment=False,
verbose=verbose,
speaker_mapping=speaker_mapping
if c.use_speaker_embedding and c.use_external_speaker_embedding_file
else None,
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None,
)
if c.use_phonemes and c.compute_input_seq_cache:
@@ -60,19 +59,21 @@ def setup_loader(ap, r, verbose=False):
)
return loader
def set_filename(wav_path, out_path):
wav_file = os.path.basename(wav_path)
file_name = wav_file.split('.')[0]
file_name = wav_file.split(".")[0]
os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
wavq_path = os.path.join(out_path, "quant", file_name)
mel_path = os.path.join(out_path, "mel", file_name)
wav_gl_path = os.path.join(out_path, "wav_gl", file_name+'.wav')
wav_path = os.path.join(out_path, "wav", file_name+'.wav')
wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav")
wav_path = os.path.join(out_path, "wav", file_name + ".wav")
return file_name, wavq_path, mel_path, wav_gl_path, wav_path
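For reference, a hedged usage sketch of set_filename as defined above; the input path is a placeholder:

```python
# Illustrative call; the wav path is made up.
name, wavq_path, mel_path, wav_gl_path, wav_path = set_filename("/data/wavs/LJ001-0001.wav", "/out")
# name        -> "LJ001-0001"
# wavq_path   -> "/out/quant/LJ001-0001"
# mel_path    -> "/out/mel/LJ001-0001"
# wav_gl_path -> "/out/wav_gl/LJ001-0001.wav"
# wav_path    -> "/out/wav/LJ001-0001.wav"
```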
def format_data(data):
# setup input data
text_input = data[0]
@@ -123,10 +124,22 @@ def format_data(data):
item_idx,
)
@torch.no_grad()
def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
def inference(
model_name,
model,
ap,
text_input,
text_lengths,
mel_input,
mel_lengths,
attn_mask=None,
speaker_ids=None,
speaker_embeddings=None,
):
if model_name == "glow_tts":
mel_input = mel_input.permute(0, 2, 1) # B x D x T
mel_input = mel_input.permute(0, 2, 1) # B x D x T
speaker_c = None
if speaker_ids is not None:
speaker_c = speaker_ids
@@ -140,7 +153,13 @@ def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_le
elif "tacotron" in model_name:
_, postnet_outputs, *_ = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
text_input,
text_lengths,
mel_input,
mel_lengths,
speaker_ids=speaker_ids,
speaker_embeddings=speaker_embeddings,
)
# normalize tacotron output
if model_name == "tacotron":
mel_specs = []
@@ -154,7 +173,10 @@ def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_le
model_output = postnet_outputs.detach().cpu().numpy()
return model_output
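The permute noted as `B x D x T` above reflects glow_tts taking mels channel-first; a quick shape sketch:

```python
# Sketch of the layout change behind the "B x D x T" comment above.
import torch

mel = torch.zeros(2, 100, 80)  # B x T x D: batch, frames, mel bins
mel = mel.permute(0, 2, 1)     # B x D x T
print(mel.shape)               # torch.Size([2, 80, 100])
```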
def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"):
def extract_spectrograms(
data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"
):
model.eval()
export_metadata = []
for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):
@@ -173,7 +195,18 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
item_idx,
) = format_data(data)
model_output = inference(c.model.lower(), model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)
model_output = inference(
c.model.lower(),
model,
ap,
text_input,
text_lengths,
mel_input,
mel_lengths,
attn_mask,
speaker_ids,
speaker_embeddings,
)
for idx in range(text_input.shape[0]):
wav_file_path = item_idx[idx]
@@ -204,13 +237,14 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
for data in export_metadata:
f.write(f"{data[0]}|{data[1]+'.npy'}\n")
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data, symbols, phonemes, model_characters, speaker_mapping
# Audio processor
ap = AudioProcessor(**c.audio)
if "characters" in c.keys() and c['characters']:
if "characters" in c.keys() and c["characters"]:
symbols, phonemes = make_symbols(**c.characters)
# set model characters
@@ -242,37 +276,26 @@ def main(args): # pylint: disable=redefined-outer-name
r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
own_loader = setup_loader(ap, r, verbose=True)
extract_spectrograms(own_loader, model, ap, args.output_path, quantized_wav=args.quantized, save_audio=args.save_audio, debug=args.debug, metada_name="metada.txt")
extract_spectrograms(
own_loader,
model,
ap,
args.output_path,
quantized_wav=args.quantized,
save_audio=args.save_audio,
debug=args.debug,
metada_name="metada.txt",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--config_path',
type=str,
help='Path to config file for training.',
required=True)
parser.add_argument(
'--checkpoint_path',
type=str,
help='Model file to be restored.',
required=True)
parser.add_argument(
'--output_path',
type=str,
help='Path to save mel specs',
required=True)
parser.add_argument('--debug',
default=False,
action='store_true',
help='Save audio files for debug')
parser.add_argument('--save_audio',
default=False,
action='store_true',
help='Save audio files')
parser.add_argument('--quantized',
action='store_true',
help='Save quantized audio files')
parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
args = parser.parse_args()
c = load_config(args.config_path)
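The reformat leaves the CLI itself unchanged; a hedged invocation sketch with placeholder paths, mirroring the argparse flags above:

```python
# Placeholder paths; the flags match the argparse definitions above.
import subprocess

subprocess.run(
    [
        "python", "TTS/bin/extract_tts_spectrograms.py",
        "--config_path", "config.json",
        "--checkpoint_path", "best_model.pth.tar",
        "--output_path", "specs/",
        "--save_audio",  # optional: also write reconstructed wavs
    ],
    check=True,
)
```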

View File

@@ -2,10 +2,10 @@
# TODO: mixed precision training
"""Trains GAN based vocoder model."""
import itertools
import os
import sys
import time
import itertools
import traceback
from inspect import signature
@@ -496,8 +496,12 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
optimizer_disc = getattr(torch.optim, c.optimizer)
if c.discriminator_model == 'hifigan_discriminator':
optimizer_disc = optimizer_disc(itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()), lr=c.lr_disc, **c.optimizer_params)
if c.discriminator_model == "hifigan_discriminator":
optimizer_disc = optimizer_disc(
itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()),
lr=c.lr_disc,
**c.optimizer_params,
)
else:
optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
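The hifigan_discriminator branch above builds a single optimizer over two discriminator submodules; a minimal standalone sketch of the pattern, with illustrative stand-ins for the modules and the optimizer class:

```python
import itertools

import torch

msd = torch.nn.Linear(4, 4)  # stands in for model_disc.msd
mpd = torch.nn.Linear(4, 4)  # stands in for model_disc.mpd

# One optimizer steps both parameter sets, as in the hunk above; the real
# class comes from getattr(torch.optim, c.optimizer).
optimizer_disc = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()), lr=1e-4)
```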

View File

@@ -1,5 +1,6 @@
from dataclasses import asdict, dataclass
from typing import List, Union
from coqpit import MISSING, Coqpit, check_argument
@@ -151,11 +152,14 @@ class BaseDatasetConfig(Coqpit):
Path to the file that lists the attention mask files used with models that require attention masks to
train the duration predictor.
"""
name: str = ''
path: str = ''
meta_file_train: Union[str, List] = '' # TODO: don't take ignored speakers for multi-speaker datasets over this. This is Union for SC-Glow compat.
meta_file_val: str = ''
meta_file_attn_mask: str = ''
name: str = ""
path: str = ""
meta_file_train: Union[
str, List
] = "" # TODO: don't take ignored speakers for multi-speaker datasets over this. This is Union for SC-Glow compat.
meta_file_val: str = ""
meta_file_attn_mask: str = ""
def check_values(
self,
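A hedged sketch of filling the fields above; Coqpit configs are dataclasses, so they accept the fields as keyword arguments (values are placeholders):

```python
# Placeholder values for the BaseDatasetConfig dataclass defined above.
dataset_config = BaseDatasetConfig(
    name="ljspeech",
    path="/data/LJSpeech-1.1/",
    meta_file_train="metadata.csv",
)
```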

View File

@@ -59,7 +59,6 @@ class GlowTTSConfig(BaseTTSConfig):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""
model: str = "glow_tts"
# model params

View File

@@ -22,6 +22,7 @@ class GSTConfig(Coqpit):
gst_num_style_tokens (int):
Number of style token vectors. Defaults to 10.
"""
gst_style_input_wav: str = None
gst_style_input_weights: dict = None
gst_embedding_dim: int = 256
@@ -130,7 +131,7 @@ class BaseTTSConfig(BaseTrainingConfig):
datasets (List[BaseDatasetConfig]):
List of datasets used for training. If multiple datasets are provided, they are merged and used together
for training.
"""
"""
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
# phoneme settings

View File

@@ -113,4 +113,4 @@ class SpeedySpeechConfig(BaseTTSConfig):
# overrides
min_seq_len: int = 13
max_seq_len: int = 200
r: int = 1 #DO NOT CHANGE
r: int = 1 # DO NOT CHANGE

View File

@@ -90,7 +90,6 @@ class MultibandMelganConfig(BaseGANVocoderConfig):
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
"""
model: str = "multiband_melgan"
# Model specific params
@@ -142,4 +141,4 @@ class MultibandMelganConfig(BaseGANVocoderConfig):
mse_G_loss_weight: float = 2.5
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 108
l1_spec_loss_weight: float = 0
l1_spec_loss_weight: float = 0

View File

@@ -66,6 +66,7 @@ class WavegradConfig(BaseVocoderConfig):
lr_scheduler_params (dict):
kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`
"""
model: str = "wavegrad"
# Model specific params
generator_model: str = "wavegrad"

View File

@@ -3,13 +3,9 @@ import unittest
import torch
from tests import get_tests_input_path
from tests import get_tests_output_path, run_cli
from TTS.tts.utils.generic_utils import setup_model
from tests import get_tests_input_path, get_tests_output_path, run_cli
from TTS.config import load_config
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.text.symbols import phonemes, symbols
torch.manual_seed(1)
@@ -20,8 +16,8 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
def test_GlowTTS():
# set paths
config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json")
checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
# load config
c = load_config(config_path)
# create model
@@ -30,14 +26,17 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
# save model
torch.save({"model": model.state_dict()}, checkpoint_path)
# run test
run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
run_cli(
f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
)
run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')
@staticmethod
def test_Tacotron2():
# set paths
config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json")
checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
# load config
c = load_config(config_path)
# create model
@@ -46,14 +45,17 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
# save model
torch.save({"model": model.state_dict()}, checkpoint_path)
# run test
run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
run_cli(
f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
)
run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')
@staticmethod
def test_Tacotron():
# set paths
config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json")
checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
# load config
c = load_config(config_path)
# create model
@@ -62,5 +64,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
# save model
torch.save({"model": model.state_dict()}, checkpoint_path)
# run test
run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
run_cli(
f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
)
run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')
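run_cli is imported from the tests package; its implementation is not shown here, but a plausible sketch (an assumption, not the actual helper) is a thin wrapper that fails the test on a non-zero exit status:

```python
import os

def run_cli(command: str) -> None:
    # Assumed behavior: shell out and assert the command succeeded.
    exit_status = os.system(command)
    assert exit_status == 0, f"command failed: {command}"
```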

View File

@@ -130,6 +130,7 @@ class GlowTTSTrainTest(unittest.TestCase):
)
count += 1
class GlowTTSInferenceTest(unittest.TestCase):
@staticmethod
def test_inference():
@@ -174,13 +175,12 @@ class GlowTTSInferenceTest(unittest.TestCase):
print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
# inference encoder and decoder with MAS
y, *_ = model.inference_with_MAS(
input_dummy, input_lengths, mel_spec, mel_lengths, None
)
y, *_ = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths, None)
y_dec, _ = model.decoder_inference(mel_spec, mel_lengths
)
y_dec, _ = model.decoder_inference(mel_spec, mel_lengths)
assert (y_dec.shape == y.shape), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
y.shape, y_dec.shape
)
assert (
y_dec.shape == y.shape
), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
y.shape, y_dec.shape
)