Eren Gölge 2021-05-15 23:48:31 +02:00
parent 8b1014d188
commit 12722501bb
11 changed files with 115 additions and 80 deletions

View File

@@ -1 +1 @@
-__version__ = '0.0.13.2'
+__version__ = "0.0.13.2"

View File

@@ -1,25 +1,26 @@
 #!/usr/bin/env python3
 """Extract Mel spectrograms with teacher forcing."""
-import os
 import argparse
+import os
+
 import numpy as np
-from tqdm import tqdm
 import torch
 from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from TTS.config import load_config
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.speakers import parse_speakers
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
-from TTS.config import load_config
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters

 use_cuda = torch.cuda.is_available()


 def setup_loader(ap, r, verbose=False):
     dataset = MyDataset(
         r,
@@ -38,9 +39,7 @@ def setup_loader(ap, r, verbose=False):
         enable_eos_bos=c.enable_eos_bos_chars,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_mapping=speaker_mapping
-        if c.use_speaker_embedding and c.use_external_speaker_embedding_file
-        else None,
+        speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None,
     )

     if c.use_phonemes and c.compute_input_seq_cache:
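Note: the `speaker_mapping` gated above is, roughly, a dict keyed by utterance. A minimal sketch of the structure this path assumes (hypothetical keys and embedding size, not taken from this diff):

    speaker_mapping = {
        "speaker_1/utt_001.wav": {"name": "speaker_1", "embedding": [0.1] * 256},
        "speaker_2/utt_042.wav": {"name": "speaker_2", "embedding": [0.2] * 256},
    }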
@@ -60,19 +59,21 @@ def setup_loader(ap, r, verbose=False):
     )
     return loader

+
 def set_filename(wav_path, out_path):
     wav_file = os.path.basename(wav_path)
-    file_name = wav_file.split('.')[0]
+    file_name = wav_file.split(".")[0]
     os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
     wavq_path = os.path.join(out_path, "quant", file_name)
     mel_path = os.path.join(out_path, "mel", file_name)
-    wav_gl_path = os.path.join(out_path, "wav_gl", file_name+'.wav')
-    wav_path = os.path.join(out_path, "wav", file_name+'.wav')
+    wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav")
+    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
     return file_name, wavq_path, mel_path, wav_gl_path, wav_path

+
 def format_data(data):
     # setup input data
     text_input = data[0]
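Note: a quick usage sketch of `set_filename` as reformatted above (hypothetical paths):

    file_name, wavq_path, mel_path, wav_gl_path, wav_path = set_filename("/data/wavs/utt_001.wav", "out/")
    # file_name == "utt_001"
    # mel_path  == "out/mel/utt_001"  (arrays get an .npy suffix when saved)
    # wav_path  == "out/wav/utt_001.wav"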
@@ -123,8 +124,20 @@ def format_data(data):
         item_idx,
     )

+
 @torch.no_grad()
-def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
+def inference(
+    model_name,
+    model,
+    ap,
+    text_input,
+    text_lengths,
+    mel_input,
+    mel_lengths,
+    attn_mask=None,
+    speaker_ids=None,
+    speaker_embeddings=None,
+):
     if model_name == "glow_tts":
         mel_input = mel_input.permute(0, 2, 1)  # B x D x T
         speaker_c = None
@@ -140,7 +153,13 @@ def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_le
     elif "tacotron" in model_name:
         _, postnet_outputs, *_ = model(
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
+            text_input,
+            text_lengths,
+            mel_input,
+            mel_lengths,
+            speaker_ids=speaker_ids,
+            speaker_embeddings=speaker_embeddings,
+        )
         # normalize tacotron output
         if model_name == "tacotron":
             mel_specs = []
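Note: the starred unpacking keeps only the second return value. The assumed output order of the Tacotron models here (a sketch from memory, not verified against every version):

    # decoder_outputs, postnet_outputs, alignments, stop_tokens = model(...)
    # so `_, postnet_outputs, *_` discards everything except postnet_outputs.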
@@ -154,7 +173,10 @@ def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_le
         model_output = postnet_outputs.detach().cpu().numpy()
     return model_output

+
-def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"):
+def extract_spectrograms(
+    data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"
+):
     model.eval()
     export_metadata = []
     for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):
@@ -173,7 +195,18 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
             item_idx,
         ) = format_data(data)

-        model_output = inference(c.model.lower(), model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)
+        model_output = inference(
+            c.model.lower(),
+            model,
+            ap,
+            text_input,
+            text_lengths,
+            mel_input,
+            mel_lengths,
+            attn_mask,
+            speaker_ids,
+            speaker_embeddings,
+        )

         for idx in range(text_input.shape[0]):
             wav_file_path = item_idx[idx]
@@ -204,13 +237,14 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
         for data in export_metadata:
             f.write(f"{data[0]}|{data[1]+'.npy'}\n")

+
 def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
     global meta_data, symbols, phonemes, model_characters, speaker_mapping

     # Audio processor
     ap = AudioProcessor(**c.audio)

-    if "characters" in c.keys() and c['characters']:
+    if "characters" in c.keys() and c["characters"]:
         symbols, phonemes = make_symbols(**c.characters)

     # set model characters
@@ -242,37 +276,26 @@ def main(args):  # pylint: disable=redefined-outer-name
     r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
     own_loader = setup_loader(ap, r, verbose=True)

-    extract_spectrograms(own_loader, model, ap, args.output_path, quantized_wav=args.quantized, save_audio=args.save_audio, debug=args.debug, metada_name="metada.txt")
+    extract_spectrograms(
+        own_loader,
+        model,
+        ap,
+        args.output_path,
+        quantized_wav=args.quantized,
+        save_audio=args.save_audio,
+        debug=args.debug,
+        metada_name="metada.txt",
+    )


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--config_path',
-        type=str,
-        help='Path to config file for training.',
-        required=True)
-    parser.add_argument(
-        '--checkpoint_path',
-        type=str,
-        help='Model file to be restored.',
-        required=True)
-    parser.add_argument(
-        '--output_path',
-        type=str,
-        help='Path to save mel specs',
-        required=True)
-    parser.add_argument('--debug',
-                        default=False,
-                        action='store_true',
-                        help='Save audio files for debug')
-    parser.add_argument('--save_audio',
-                        default=False,
-                        action='store_true',
-                        help='Save audio files')
-    parser.add_argument('--quantized',
-                        action='store_true',
-                        help='Save quantized audio files')
+    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
+    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
+    parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
+    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
+    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
+    parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
     args = parser.parse_args()

     c = load_config(args.config_path)
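Note: a minimal CPU-only invocation of this script, mirroring the run_cli commands in the tests further down (hypothetical paths):

    CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py \
        --config_path config.json \
        --checkpoint_path checkpoint.pth.tar \
        --output_path out/ \
        --save_audio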

View File

@@ -2,10 +2,10 @@
 # TODO: mixed precision training
 """Trains GAN based vocoder model."""
+import itertools
 import os
 import sys
 import time
-import itertools
 import traceback
 from inspect import signature
@@ -496,8 +496,12 @@ def main(args):  # pylint: disable=redefined-outer-name
     optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
     optimizer_disc = getattr(torch.optim, c.optimizer)

-    if c.discriminator_model == 'hifigan_discriminator':
-        optimizer_disc = optimizer_disc(itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()), lr=c.lr_disc, **c.optimizer_params)
+    if c.discriminator_model == "hifigan_discriminator":
+        optimizer_disc = optimizer_disc(
+            itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()),
+            lr=c.lr_disc,
+            **c.optimizer_params,
+        )
     else:
         optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
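Note: `itertools.chain` feeds both discriminator sub-networks to a single optimizer without materializing a combined list. A self-contained sketch (stand-in modules and learning rate, not the real models):

    import itertools

    import torch

    msd = torch.nn.Linear(4, 4)  # stand-ins for model_disc.msd / model_disc.mpd
    mpd = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()), lr=1e-4)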

View File

@@ -1,5 +1,6 @@
 from dataclasses import asdict, dataclass
 from typing import List, Union
+
 from coqpit import MISSING, Coqpit, check_argument
@@ -151,11 +152,14 @@ class BaseDatasetConfig(Coqpit):
             Path to the file that lists the attention mask files used with models that require attention masks to
             train the duration predictor.
     """
-    name: str = ''
-    path: str = ''
-    meta_file_train: Union[str, List] = ''  # TODO: don't take ignored speakers for multi-speaker datasets over this. This is Union for SC-Glow compat.
-    meta_file_val: str = ''
-    meta_file_attn_mask: str = ''
+
+    name: str = ""
+    path: str = ""
+    meta_file_train: Union[
+        str, List
+    ] = ""  # TODO: don't take ignored speakers for multi-speaker datasets over this. This is Union for SC-Glow compat.
+    meta_file_val: str = ""
+    meta_file_attn_mask: str = ""

     def check_values(
         self,
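Note: Coqpit configs behave like plain dataclasses; a minimal sketch of filling these fields (hypothetical dataset values):

    dataset_config = BaseDatasetConfig(
        name="ljspeech",
        path="/data/LJSpeech-1.1/",
        meta_file_train="metadata.csv",
        meta_file_val="",
    )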

View File

@@ -59,7 +59,6 @@ class GlowTTSConfig(BaseTTSConfig):
         Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
     """
-
     model: str = "glow_tts"

     # model params

View File

@@ -22,6 +22,7 @@ class GSTConfig(Coqpit):
         gst_num_style_tokens (int):
             Number of style token vectors. Defaults to 10.
     """
+
     gst_style_input_wav: str = None
     gst_style_input_weights: dict = None
     gst_embedding_dim: int = 256

View File

@@ -113,4 +113,4 @@ class SpeedySpeechConfig(BaseTTSConfig):
     # overrides
     min_seq_len: int = 13
     max_seq_len: int = 200
-    r: int = 1 #DO NOT CHANGE
+    r: int = 1  # DO NOT CHANGE

View File

@@ -90,7 +90,6 @@ class MultibandMelganConfig(BaseGANVocoderConfig):
         L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
     """
-
     model: str = "multiband_melgan"

     # Model specific params

View File

@@ -66,6 +66,7 @@ class WavegradConfig(BaseVocoderConfig):
         lr_scheduler_params (dict):
             kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`
     """
+
     model: str = "wavegrad"
     # Model specific params
     generator_model: str = "wavegrad"

View File

@@ -3,13 +3,9 @@ import unittest
 import torch

-from tests import get_tests_input_path
-from tests import get_tests_output_path, run_cli
-from TTS.tts.utils.generic_utils import setup_model
+from tests import get_tests_input_path, get_tests_output_path, run_cli
 from TTS.config import load_config
+from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.text.symbols import phonemes, symbols

 torch.manual_seed(1)
@@ -20,8 +16,8 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
     def test_GlowTTS():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
-        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
+        output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
         # create model
@@ -30,14 +26,17 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
         # save model
         torch.save({"model": model.state_dict()}, checkpoint_path)
         # run test
-        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(
+            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
+        )
         run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')

     @staticmethod
     def test_Tacotron2():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
-        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
+        output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
         # create model
@@ -46,14 +45,17 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
         # save model
         torch.save({"model": model.state_dict()}, checkpoint_path)
         # run test
-        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(
+            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
+        )
         run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')

     @staticmethod
     def test_Tacotron():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
-        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
+        output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
         # create model
@@ -62,5 +64,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
         # save model
         torch.save({"model": model.state_dict()}, checkpoint_path)
         # run test
-        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(
+            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
+        )
         run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')

View File

@@ -130,6 +130,7 @@ class GlowTTSTrainTest(unittest.TestCase):
             )
             count += 1

+
 class GlowTTSInferenceTest(unittest.TestCase):
     @staticmethod
     def test_inference():
@@ -174,13 +175,12 @@ class GlowTTSInferenceTest(unittest.TestCase):
         print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))

         # inference encoder and decoder with MAS
-        y, *_ = model.inference_with_MAS(
-            input_dummy, input_lengths, mel_spec, mel_lengths, None
-        )
+        y, *_ = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths, None)

-        y_dec, _ = model.decoder_inference(mel_spec, mel_lengths
-        )
+        y_dec, _ = model.decoder_inference(mel_spec, mel_lengths)

-        assert (y_dec.shape == y.shape), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
+        assert (
+            y_dec.shape == y.shape
+        ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
             y.shape, y_dec.shape
         )