diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
new file mode 100755
index 00000000..d5c23ccd
--- /dev/null
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -0,0 +1,280 @@
+#!/usr/bin/env python3
+"""Extract Mel spectrograms with teacher forcing."""
+
+import os
+import argparse
+import numpy as np
+from tqdm import tqdm
+import torch
+
+from torch.utils.data import DataLoader
+
+from TTS.tts.datasets.preprocess import load_meta_data
+from TTS.tts.datasets.TTSDataset import MyDataset
+from TTS.tts.utils.generic_utils import setup_model
+from TTS.tts.utils.speakers import parse_speakers
+from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
+from TTS.utils.io import load_config
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import count_parameters
+
+use_cuda = torch.cuda.is_available()
+
+def setup_loader(ap, r, verbose=False):
+    dataset = MyDataset(
+        r,
+        c.text_cleaner,
+        compute_linear_spec=False,
+        meta_data=meta_data,
+        ap=ap,
+        tp=c.characters if "characters" in c.keys() else None,
+        add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
+        batch_group_size=0,
+        min_seq_len=c.min_seq_len,
+        max_seq_len=c.max_seq_len,
+        phoneme_cache_path=c.phoneme_cache_path,
+        use_phonemes=c.use_phonemes,
+        phoneme_language=c.phoneme_language,
+        enable_eos_bos=c.enable_eos_bos_chars,
+        use_noise_augment=False,
+        verbose=verbose,
+        speaker_mapping=speaker_mapping
+        if c.use_speaker_embedding and c.use_external_speaker_embedding_file
+        else None,
+    )
+
+    if c.use_phonemes and c.compute_input_seq_cache:
+        # precompute phonemes to have a better estimate of sequence lengths.
+        dataset.compute_input_seq(c.num_loader_workers)
+    dataset.sort_items()
+
+    loader = DataLoader(
+        dataset,
+        batch_size=c.batch_size,
+        shuffle=False,
+        collate_fn=dataset.collate_fn,
+        drop_last=False,
+        sampler=None,
+        num_workers=c.num_loader_workers,
+        pin_memory=False,
+    )
+    return loader
+
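+# set_filename() below prepares the per-utterance output paths and extract_spectrograms()
+# fills them: quant/ (quantized wavs, only with --quantized), mel/ (teacher-forced mel
+# spectrograms saved as .npy, always), wav_gl/ (audio re-synthesized from the extracted
+# mel via ap.inv_melspectrogram, only with --debug) and wav/ (copies of the source wavs,
+# only with --save_audio).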
+def set_filename(wav_path, out_path):
+    wav_file = os.path.basename(wav_path)
+    file_name = wav_file.split('.')[0]
+    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
+    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
+    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
+    os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
+    wavq_path = os.path.join(out_path, "quant", file_name)
+    mel_path = os.path.join(out_path, "mel", file_name)
+    wav_gl_path = os.path.join(out_path, "wav_gl", file_name + '.wav')
+    wav_path = os.path.join(out_path, "wav", file_name + '.wav')
+    return file_name, wavq_path, mel_path, wav_gl_path, wav_path
+
+def format_data(data):
+    # setup input data
+    text_input = data[0]
+    text_lengths = data[1]
+    speaker_names = data[2]
+    mel_input = data[4]
+    mel_lengths = data[5]
+    item_idx = data[7]
+    attn_mask = data[9]
+    avg_text_length = torch.mean(text_lengths.float())
+    avg_spec_length = torch.mean(mel_lengths.float())
+
+    if c.use_speaker_embedding:
+        if c.use_external_speaker_embedding_file:
+            speaker_embeddings = data[8]
+            speaker_ids = None
+        else:
+            speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
+            speaker_ids = torch.LongTensor(speaker_ids)
+            speaker_embeddings = None
+    else:
+        speaker_embeddings = None
+        speaker_ids = None
+
+    # dispatch data to GPU
+    if use_cuda:
+        text_input = text_input.cuda(non_blocking=True)
+        text_lengths = text_lengths.cuda(non_blocking=True)
+        mel_input = mel_input.cuda(non_blocking=True)
+        mel_lengths = mel_lengths.cuda(non_blocking=True)
+        if speaker_ids is not None:
+            speaker_ids = speaker_ids.cuda(non_blocking=True)
+        if speaker_embeddings is not None:
+            speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
+
+        if attn_mask is not None:
+            attn_mask = attn_mask.cuda(non_blocking=True)
+    return (
+        text_input,
+        text_lengths,
+        mel_input,
+        mel_lengths,
+        speaker_ids,
+        speaker_embeddings,
+        avg_text_length,
+        avg_spec_length,
+        attn_mask,
+        item_idx,
+    )
+
+@torch.no_grad()
+def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
+    if model_name == "glow_tts":
+        mel_input = mel_input.permute(0, 2, 1)  # B x D x T
+        speaker_c = None
+        if speaker_ids is not None:
+            speaker_c = speaker_ids
+        elif speaker_embeddings is not None:
+            speaker_c = speaker_embeddings
+
+        model_output, *_ = model.inference_with_MAS(
+            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
+        )
+        model_output = model_output.transpose(1, 2).detach().cpu().numpy()
+
+    elif "tacotron" in model_name:
+        _, postnet_outputs, *_ = model(
+            text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
+        # Tacotron's postnet predicts a linear spectrogram; convert it back to mel.
+        if model_name == "tacotron":
+            mel_specs = []
+            postnet_outputs = postnet_outputs.data.cpu().numpy()
+            for b in range(postnet_outputs.shape[0]):
+                postnet_output = postnet_outputs[b]
+                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
+            model_output = torch.stack(mel_specs).cpu().numpy()
+
+        elif model_name == "tacotron2":
+            model_output = postnet_outputs.detach().cpu().numpy()
+    return model_output
+
+def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"):
+    model.eval()
+    export_metadata = []
+    for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):
+
+        # format data
+        (
+            text_input,
+            text_lengths,
+            mel_input,
+            mel_lengths,
+            speaker_ids,
+            speaker_embeddings,
+            _,
+            _,
+            attn_mask,
+            item_idx,
+        ) = format_data(data)
+
+        model_output = inference(c.model.lower(), model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)
+
+        for idx in range(text_input.shape[0]):
+            wav_file_path = item_idx[idx]
+            wav = ap.load_wav(wav_file_path)
+            _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)
+
+            # quantize and save wav
+            if quantized_wav:
+                wavq = ap.quantize(wav)
+                np.save(wavq_path, wavq)
+
+            # save TTS mel
+            mel = model_output[idx]
+            mel_length = mel_lengths[idx]
+            mel = mel[:mel_length, :].T
+            np.save(mel_path, mel)
+
+            export_metadata.append([wav_file_path, mel_path])
+            if save_audio:
+                ap.save_wav(wav, wav_path)
+
+            if debug:
+                print("Audio for debug saved at:", wav_gl_path)
+                wav = ap.inv_melspectrogram(mel)
+                ap.save_wav(wav, wav_gl_path)
+
+    # write one "<wav_path>|<mel_path>.npy" line per utterance
+    with open(os.path.join(output_path, metada_name), "w") as f:
+        for data in export_metadata:
+            f.write(f"{data[0]}|{data[1]+'.npy'}\n")
+
+def main(args):  # pylint: disable=redefined-outer-name
+    # pylint: disable=global-variable-undefined
+    global meta_data, symbols, phonemes, model_characters, speaker_mapping
+
+    # Audio processor
+    ap = AudioProcessor(**c.audio)
+    if "characters" in c.keys():
+        symbols, phonemes = make_symbols(**c.characters)
+
+    # set model characters
+    model_characters = phonemes if c.use_phonemes else symbols
+    num_chars = len(model_characters)
+
+    # load data instances
+    meta_data_train, meta_data_eval = load_meta_data(c.datasets)
+
+    # use eval and training partitions
+    meta_data = meta_data_train + meta_data_eval
+
+    # parse speakers
+    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, None)
+
+    # setup model
+    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
+
+    # restore model
+    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+    model.load_state_dict(checkpoint["model"])
+
+    if use_cuda:
+        model.cuda()
+
+    num_params = count_parameters(model)
+    print("\n > Model has {} parameters".format(num_params), flush=True)
+    # set r
+    r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
+    own_loader = setup_loader(ap, r, verbose=True)
+
+    extract_spectrograms(own_loader, model, ap, args.output_path, quantized_wav=args.quantized, save_audio=args.save_audio, debug=args.debug, metada_name="metada.txt")
+
+
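+# Example invocation (paths are illustrative; see tests/test_extract_tts_spectrograms.py
+# for a runnable version):
+#   python TTS/bin/extract_tts_spectrograms.py \
+#       --config_path config.json --checkpoint_path checkpoint.pth.tar --output_path out/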
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        help='Path to config file for training.',
+        required=True)
+    parser.add_argument(
+        '--checkpoint_path',
+        type=str,
+        help='Model file to be restored.',
+        required=True)
+    parser.add_argument(
+        '--output_path',
+        type=str,
+        help='Path to save mel specs',
+        required=True)
+    parser.add_argument('--debug',
+                        default=False,
+                        action='store_true',
+                        help='Save audio files for debug')
+    parser.add_argument('--save_audio',
+                        default=False,
+                        action='store_true',
+                        help='Save audio files')
+    parser.add_argument('--quantized',
+                        action='store_true',
+                        help='Save quantized audio files')
+    args = parser.parse_args()
+
+    c = load_config(args.config_path)
+
+    main(args)
diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py
old mode 100644
new mode 100755
index c8346c3a..9bfa4296
--- a/TTS/bin/train_tacotron.py
+++ b/TTS/bin/train_tacotron.py
@@ -117,7 +117,7 @@ def format_data(data):
         text_lengths = text_lengths.cuda(non_blocking=True)
         mel_input = mel_input.cuda(non_blocking=True)
         mel_lengths = mel_lengths.cuda(non_blocking=True)
-        linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron"] else None
+        linear_input = linear_input.cuda(non_blocking=True) if c.model.lower() in ["tacotron"] else None
         stop_targets = stop_targets.cuda(non_blocking=True)
         if speaker_ids is not None:
             speaker_ids = speaker_ids.cuda(non_blocking=True)
diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py
old mode 100644
new mode 100755
index 730506c1..f33df3e8
--- a/TTS/bin/train_vocoder_gan.py
+++ b/TTS/bin/train_vocoder_gan.py
@@ -5,6 +5,7 @@ import os
 import sys
 import time
+import itertools
 import traceback
 from inspect import signature
@@ -495,7 +496,11 @@ def main(args):  # pylint: disable=redefined-outer-name
     optimizer_gen = getattr(torch.optim, c.optimizer)
     optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
     optimizer_disc = getattr(torch.optim, c.optimizer)
-    optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
+
+    if c.discriminator_model == 'hifigan_discriminator':
+        optimizer_disc = optimizer_disc(itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()), lr=c.lr_disc, **c.optimizer_params)
+    else:
+        optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
 
     # schedulers
     scheduler_gen = None
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
old mode 100644
new mode 100755
index 1b84b451..19eb594a
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -179,6 +179,81 @@ class GlowTTS(nn.Module):
         attn = attn.squeeze(1).permute(0, 2, 1)
         return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
 
+    @torch.no_grad()
+    def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
+        """
+        It's similar to the teacher forcing in Tacotron.
+        It was proposed in: https://arxiv.org/abs/2104.05557
+        Shapes:
+            x: [B, T]
+            x_lengths: B
+            y: [B, C, T]
+            y_lengths: B
+            g: [B, C] or B
+        """
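+        # Roughly: encode the text, run the flow decoder forward on the ground-truth
+        # mel to get the latents z, align z to the encoder-predicted Gaussians with
+        # monotonic alignment search (MAS), then run the decoder in reverse on the
+        # aligned means so the output follows the ground-truth durations.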
+        y_max_length = y.size(2)
+        # norm speaker embeddings
+        if g is not None:
+            if self.external_speaker_embedding_dim:
+                g = F.normalize(g).unsqueeze(-1)
+            else:
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
+
+        # embedding pass
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        # drop residual frames wrt num_squeeze and set y_lengths.
+        y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
+        # create masks
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        # decoder pass
+        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
+        # find the alignment path between z and encoder output
+        o_scale = torch.exp(-2 * o_log_scale)
+        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
+        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
+        logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
+        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
+
+        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
+        attn = attn.squeeze(1).permute(0, 2, 1)
+
+        # get predicted aligned distribution
+        z = y_mean * y_mask
+
+        # reverse the decoder and predict using the aligned distribution
+        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
+
+        return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
+
+    @torch.no_grad()
+    def decoder_inference(self, y, y_lengths=None, g=None):
+        """
+        Shapes:
+            y: [B, C, T]
+            y_lengths: B
+            g: [B, C] or B
+        """
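+        # Round-trips a ground-truth mel through the flow decoder (forward to the
+        # latent z, then inverse back to a mel); handy as a decoder-only sanity
+        # check, compared against inference_with_MAS in tests/test_glow_tts.py.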
+        y_max_length = y.size(2)
+        # norm speaker embeddings
+        if g is not None:
+            if self.external_speaker_embedding_dim:
+                g = F.normalize(g).unsqueeze(-1)
+            else:
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
+
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
+
+        # decoder pass
+        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
+
+        # reverse decoder and predict
+        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
+
+        return y, logdet
+
     @torch.no_grad()
     def inference(self, x, x_lengths, g=None):
         if g is not None:
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
old mode 100644
new mode 100755
index 583b46a8..b80e8ee3
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -28,9 +28,10 @@ def load_speaker_mapping(out_path):
 
 def save_speaker_mapping(out_path, speaker_mapping):
     """Saves speaker mapping if not yet present."""
-    speakers_json_path = make_speakers_json_path(out_path)
-    with open(speakers_json_path, "w") as f:
-        json.dump(speaker_mapping, f, indent=4)
+    if out_path is not None:
+        speakers_json_path = make_speakers_json_path(out_path)
+        with open(speakers_json_path, "w") as f:
+            json.dump(speaker_mapping, f, indent=4)
 
 
 def get_speakers(items):
diff --git a/tests/test_extract_tts_spectrograms.py b/tests/test_extract_tts_spectrograms.py
new file mode 100644
index 00000000..65db9c0e
--- /dev/null
+++ b/tests/test_extract_tts_spectrograms.py
@@ -0,0 +1,66 @@
+import os
+import unittest
+
+import torch
+
+from tests import get_tests_input_path
+
+from tests import get_tests_output_path, run_cli
+
+from TTS.tts.utils.generic_utils import setup_model
+
+from TTS.utils.io import load_config
+from TTS.tts.utils.text.symbols import phonemes, symbols
+
+torch.manual_seed(1)
+
+# pylint: disable=protected-access
+class TestExtractTTSSpectrograms(unittest.TestCase):
+    @staticmethod
+    def test_GlowTTS():
+        # set paths
+        config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json")
+        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
+        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        # load config
+        c = load_config(config_path)
+        # create model
+        num_chars = len(phonemes if c.use_phonemes else symbols)
+        model = setup_model(num_chars, 1, c, speaker_embedding_dim=None)
+        # save model
+        torch.save({"model": model.state_dict()}, checkpoint_path)
+        # run test
+        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')
+    @staticmethod
+    def test_Tacotron2():
+        # set paths
+        config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json")
+        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
+        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        # load config
+        c = load_config(config_path)
+        # create model
+        num_chars = len(phonemes if c.use_phonemes else symbols)
+        model = setup_model(num_chars, 1, c, speaker_embedding_dim=None)
+        # save model
+        torch.save({"model": model.state_dict()}, checkpoint_path)
+        # run test
+        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')
+    @staticmethod
+    def test_Tacotron():
+        # set paths
+        config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json")
+        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
+        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        # load config
+        c = load_config(config_path)
+        # create model
+        num_chars = len(phonemes if c.use_phonemes else symbols)
+        model = setup_model(num_chars, 1, c, speaker_embedding_dim=None)
+        # save model
+        torch.save({"model": model.state_dict()}, checkpoint_path)
+        # run test
+        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')
diff --git a/tests/test_glow_tts.py b/tests/test_glow_tts.py
index 8e699faf..7e17ed45 100644
--- a/tests/test_glow_tts.py
+++ b/tests/test_glow_tts.py
@@ -129,3 +129,58 @@ class GlowTTSTrainTest(unittest.TestCase):
                 count, param.shape, param, param_ref
             )
             count += 1
+
+class GlowTTSInferenceTest(unittest.TestCase):
+    @staticmethod
+    def test_inference():
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
+        input_lengths[-1] = 128
+        mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
+        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
+
+        # create model
+        model = GlowTTS(
+            num_chars=32,
+            hidden_channels_enc=48,
+            hidden_channels_dec=48,
+            hidden_channels_dp=32,
+            out_channels=80,
+            encoder_type="rel_pos_transformer",
+            encoder_params={
+                "kernel_size": 3,
+                "dropout_p": 0.1,
+                "num_layers": 6,
+                "num_heads": 2,
+                "hidden_channels_ffn": 16,  # feed-forward width of the encoder transformer
+                "input_length": None,
+            },
+            use_encoder_prenet=True,
+            num_flow_blocks_dec=12,
+            kernel_size_dec=5,
+            dilation_rate=1,
+            num_block_layers=4,
+            dropout_p_dec=0.0,
+            num_speakers=0,
+            c_in_channels=0,
+            num_splits=4,
+            num_squeeze=1,
+            sigmoid_scale=False,
+            mean_only=False,
+        ).to(device)
+
+        model.eval()
+        print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
+
+        # inference encoder and decoder with MAS
+        y, *_ = model.inference_with_MAS(
+            input_dummy, input_lengths, mel_spec, mel_lengths, None
+        )
+
+        y_dec, _ = model.decoder_inference(mel_spec, mel_lengths)
+
+        assert y_dec.shape == y.shape, "Difference between the shapes of the GlowTTS inference with MAS ({}) and the inference using only the decoder ({})!".format(
+            y.shape, y_dec.shape
+        )
diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py
index e3ed8ae2..72b47d23 100644
--- a/tests/test_tacotron_model.py
+++ b/tests/test_tacotron_model.py
@@ -37,6 +37,7 @@ class TacotronTrainTest(unittest.TestCase):
         mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
         linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device)
         mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        mel_lengths[-1] = mel_spec.size(1)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
         speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
@@ -96,6 +97,7 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase):
         mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
         linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        mel_lengths[-1] = mel_spec.size(1)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
         speaker_embeddings = torch.rand(8, 55).to(device)