diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
new file mode 100755
index 00000000..1ba5b839
--- /dev/null
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env python3
+"""Extract Mel spectrograms with teacher forcing."""
+
+import os
+import argparse
+import numpy as np
+from tqdm import tqdm
+import torch
+
+from torch.utils.data import DataLoader
+
+from TTS.tts.datasets.preprocess import load_meta_data
+from TTS.tts.datasets.TTSDataset import MyDataset
+from TTS.tts.utils.generic_utils import setup_model
+from TTS.tts.utils.speakers import parse_speakers
+from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
+from TTS.utils.io import load_config
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import count_parameters
+
+use_cuda = torch.cuda.is_available()
+
+def setup_loader(ap, r, verbose=False):
+    dataset = MyDataset(
+        r,
+        c.text_cleaner,
+        compute_linear_spec=False,
+        meta_data=meta_data,
+        ap=ap,
+        tp=c.characters if "characters" in c.keys() else None,
+        add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
+        batch_group_size=0,
+        min_seq_len=c.min_seq_len,
+        max_seq_len=c.max_seq_len,
+        phoneme_cache_path=c.phoneme_cache_path,
+        use_phonemes=c.use_phonemes,
+        phoneme_language=c.phoneme_language,
+        enable_eos_bos=c.enable_eos_bos_chars,
+        use_noise_augment=False,
+        verbose=verbose,
+        speaker_mapping=speaker_mapping
+        if c.use_speaker_embedding and c.use_external_speaker_embedding_file
+        else None,
+    )
+
+    if c.use_phonemes and c.compute_input_seq_cache:
+        # precompute phonemes to have a better estimate of sequence lengths.
+        dataset.compute_input_seq(c.num_loader_workers)
+    dataset.sort_items()
+
+    loader = DataLoader(
+        dataset,
+        batch_size=c.batch_size,
+        shuffle=False,
+        collate_fn=dataset.collate_fn,
+        drop_last=False,
+        sampler=None,
+        num_workers=c.num_loader_workers,
+        pin_memory=False,
+    )
+    return loader
+
+def set_filename(wav_path, out_path):
+    wav_file = os.path.basename(wav_path)
+    file_name = wav_file.split('.')[0]
+    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
+    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
+    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
+    wavq_path = os.path.join(out_path, "quant", file_name)
+    mel_path = os.path.join(out_path, "mel", file_name)
+    wav_path = os.path.join(out_path, "wav_gl", file_name + '.wav')
+    return file_name, wavq_path, mel_path, wav_path
+
+def format_data(data):
+    # setup input data
+    text_input = data[0]
+    text_lengths = data[1]
+    speaker_names = data[2]
+    mel_input = data[4]
+    mel_lengths = data[5]
+    item_idx = data[7]
+    attn_mask = data[9]
+    avg_text_length = torch.mean(text_lengths.float())
+    avg_spec_length = torch.mean(mel_lengths.float())
+
+    if c.use_speaker_embedding:
+        if c.use_external_speaker_embedding_file:
+            speaker_embeddings = data[8]
+            speaker_ids = None
+        else:
+            speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
+            speaker_ids = torch.LongTensor(speaker_ids)
+            speaker_embeddings = None
+    else:
+        speaker_embeddings = None
+        speaker_ids = None
+
+    # dispatch data to GPU
+    if use_cuda:
+        text_input = text_input.cuda(non_blocking=True)
+        text_lengths = text_lengths.cuda(non_blocking=True)
+        mel_input = mel_input.cuda(non_blocking=True)
+        mel_lengths = mel_lengths.cuda(non_blocking=True)
+        if speaker_ids is not None:
+            speaker_ids = speaker_ids.cuda(non_blocking=True)
+        if speaker_embeddings is not None:
+            speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
+
+        if attn_mask is not None:
+            attn_mask = attn_mask.cuda(non_blocking=True)
+    return (
+        text_input,
+        text_lengths,
+        mel_input,
+        mel_lengths,
+        speaker_ids,
+        speaker_embeddings,
+        avg_text_length,
+        avg_spec_length,
+        attn_mask,
+        item_idx,
+    )
+
+@torch.no_grad()
+def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=False, debug=False, metadata_name="metadata.txt"):
+    model.eval()
+    export_metadata = []
+    for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):
+
+        # format data
+        (
+            text_input,
+            text_lengths,
+            mel_input,
+            mel_lengths,
+            speaker_ids,
+            speaker_embeddings,
+            _,
+            _,
+            attn_mask,
+            item_idx,
+        ) = format_data(data)
+
+        if c.model.lower() == "glow_tts":
+            mel_input = mel_input.permute(0, 2, 1)  # B x D x T
+            speaker_c = None
+            if speaker_ids is not None:
+                speaker_c = speaker_ids
+            elif speaker_embeddings is not None:
+                speaker_c = speaker_embeddings
+
+            model_output, _, _, _, _, _, _ = model.inference_with_MAS(
+                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
+            )
+            model_output = model_output.transpose(1, 2).detach().cpu().numpy()
+
+        elif "tacotron" in c.model.lower():
+            if c.bidirectional_decoder or c.double_decoder_consistency:
+                (
+                    _,
+                    postnet_outputs,
+                    _,
+                    _,
+                    _,
+                    _,
+                ) = model(
+                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
+                )
+            else:
+                _, postnet_outputs, _, _ = model(
+                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
+                )
+            # normalize tacotron output
+            if c.model.lower() == "tacotron":
+                mel_specs = []
+                postnet_outputs = postnet_outputs.data.cpu().numpy()
+                for b in range(postnet_outputs.shape[0]):
+                    postnet_output = postnet_outputs[b]
+                    mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
+                # keep the converted mels on CPU as a numpy array so np.save works below
+                model_output = torch.stack(mel_specs).numpy()
+
+            elif c.model.lower() == "tacotron2":
+                model_output = postnet_outputs.detach().cpu().numpy()
+
+        for idx in range(text_input.shape[0]):
+            wav_file_path = item_idx[idx]
+            wav = ap.load_wav(wav_file_path)
+            _, wavq_path, mel_path, wav_path = set_filename(wav_file_path, output_path)
+
+            # quantize and save wav
+            if quantized_wav:
+                wavq = ap.quantize(wav)
+                np.save(wavq_path, wavq)
+
+            # save TTS mel
+            mel = model_output[idx]
+            mel_length = mel_lengths[idx]
+            mel = mel[:mel_length, :].T
+            np.save(mel_path, mel)
+
+            export_metadata.append([wav_file_path, mel_path])
+
+            if debug:
+                print("Audio for debug saved at:", wav_path)
+                wav = ap.inv_melspectrogram(mel)
+                ap.save_wav(wav, wav_path)
+
+    with open(os.path.join(output_path, metadata_name), "w") as f:
+        for data in export_metadata:
+            f.write(f"{data[0]}|{data[1] + '.npy'}\n")
+
+def main(args):  # pylint: disable=redefined-outer-name
+    # pylint: disable=global-variable-undefined
+    global meta_data, symbols, phonemes, model_characters, speaker_mapping
+
+    # Audio processor
+    ap = AudioProcessor(**c.audio)
+    if "characters" in c.keys():
+        symbols, phonemes = make_symbols(**c.characters)
+
+    # set model characters
+    model_characters = phonemes if c.use_phonemes else symbols
+    num_chars = len(model_characters)
+
+    # load data instances
+    meta_data_train, meta_data_eval = load_meta_data(c.datasets)
+
+    # use eval and training partitions
+    meta_data = meta_data_train + meta_data_eval
+
+    # parse speakers
+    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, None)
+
+    # setup model
+    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
+
+    # restore model
+    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+    model.load_state_dict(checkpoint["model"])
+
+    if use_cuda:
+        model.cuda()
+
+    num_params = count_parameters(model)
+    print("\n > Model has {} parameters".format(num_params), flush=True)
+    # set r
+    r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
+    own_loader = setup_loader(ap, r, verbose=True)
+
+    extract_spectrograms(own_loader, model, ap, args.output_path, quantized_wav=args.quantized, debug=args.debug, metadata_name="metadata.txt")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        help='Path to config file for training.',
+        required=True)
+    parser.add_argument(
+        '--checkpoint_path',
+        type=str,
+        help='Model file to be restored.',
+        required=True)
+    parser.add_argument(
+        '--output_path',
+        type=str,
+        help='Path to save mel specs',
+        required=True)
+    parser.add_argument('--debug',
+                        default=False,
+                        action='store_true',
+                        help='Save audio files for debug')
+    parser.add_argument('--quantized',
+                        action='store_true',
+                        help='Save quantized audio files')
+    args = parser.parse_args()
+
+    c = load_config(args.config_path)
+
+    main(args)
diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py
old mode 100644
new mode 100755
index c8346c3a..9bfa4296
--- a/TTS/bin/train_tacotron.py
+++ b/TTS/bin/train_tacotron.py
@@ -117,7 +117,7 @@ def format_data(data):
         text_lengths = text_lengths.cuda(non_blocking=True)
         mel_input = mel_input.cuda(non_blocking=True)
         mel_lengths = mel_lengths.cuda(non_blocking=True)
-        linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron"] else None
+        linear_input = linear_input.cuda(non_blocking=True) if c.model.lower() in ["tacotron"] else None
         stop_targets = stop_targets.cuda(non_blocking=True)
         if speaker_ids is not None:
             speaker_ids = speaker_ids.cuda(non_blocking=True)
diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py
old mode 100644
new mode 100755
index 59409ad0..f33df3e8
--- a/TTS/bin/train_vocoder_gan.py
+++ b/TTS/bin/train_vocoder_gan.py
@@ -497,7 +497,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
     optimizer_disc = getattr(torch.optim, c.optimizer)
 
-    if c.discriminator_model == 'hifigan_discriminator': 
+    if c.discriminator_model == 'hifigan_discriminator':
         optimizer_disc = optimizer_disc(itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()), lr=c.lr_disc, **c.optimizer_params)
     else:
         optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
old mode 100644
new mode 100755
index 0717e2a8..fddd94cf
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -176,6 +176,81 @@ class GlowTTS(nn.Module):
         attn = attn.squeeze(1).permute(0, 2, 1)
         return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
 
+    @torch.no_grad()
+    def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
+        """
+        Inference with Monotonic Alignment Search (MAS): the ground-truth mel spectrogram
+        is used to find the alignment, similar to teacher forcing in Tacotron.
+        It was proposed in: https://arxiv.org/abs/2104.05557
+        Shapes:
+            x: [B, T]
+            x_lengths: B
+            y: [B, C, T]
+            y_lengths: B
+            g: [B, C] or B
+        """
+        y_max_length = y.size(2)
+        # norm speaker embeddings
+        if g is not None:
+            if self.external_speaker_embedding_dim:
+                g = F.normalize(g).unsqueeze(-1)
+            else:
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
+
+        # embedding pass
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        # drop residual frames wrt num_squeeze and set y_lengths.
+        y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
+        # create masks
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        # decoder pass
+        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
+        # find the alignment path between z and encoder output
+        o_scale = torch.exp(-2 * o_log_scale)
+        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
+        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
+        logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
+        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
+
+        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
+        attn = attn.squeeze(1).permute(0, 2, 1)
+
+        # get the predicted aligned distribution
+        z = y_mean * y_mask
+
+        # reverse the decoder and predict using the aligned distribution
+        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
+
+        return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
+
+    @torch.no_grad()
+    def decoder_inference(self, y, y_lengths=None, g=None):
+        """
+        Reconstruct the target spectrogram by running it through the flow decoder forward and then in reverse.
+        Shapes:
+            y: [B, C, T]
+            y_lengths: B
+            g: [B, C] or B
+        """
+        y_max_length = y.size(2)
+        # norm speaker embeddings
+        if g is not None:
+            if self.external_speaker_embedding_dim:
+                g = F.normalize(g).unsqueeze(-1)
+            else:
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
+
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
+
+        # decoder pass
+        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
+
+        # reverse decoder and predict
+        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
+
+        return y, logdet
+
     @torch.no_grad()
     def inference(self, x, x_lengths, g=None):
         if g is not None:
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
old mode 100644
new mode 100755
index cb2827fd..07061b81
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -22,9 +22,10 @@ def load_speaker_mapping(out_path):
 
 def save_speaker_mapping(out_path, speaker_mapping):
     """Saves speaker mapping if not yet present."""
-    speakers_json_path = make_speakers_json_path(out_path)
-    with open(speakers_json_path, "w") as f:
-        json.dump(speaker_mapping, f, indent=4)
+    if out_path is not None:
+        speakers_json_path = make_speakers_json_path(out_path)
+        with open(speakers_json_path, "w") as f:
+            json.dump(speaker_mapping, f, indent=4)
 
 
 def get_speakers(items):
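
For reference, a minimal sketch of how the extracted features might be consumed downstream. The paths below are hypothetical placeholders; only the metadata format (one "<wav_path>|<mel_npy_path>" pair per line, as written by extract_spectrograms) and the transposed [num_mels, T] mel layout come from the script above:

import os
import numpy as np

# hypothetical directory that was passed as --output_path to extract_tts_spectrograms.py
output_path = "output"

# each metadata line pairs the source wav with the saved mel .npy file
with open(os.path.join(output_path, "metadata.txt"), encoding="utf-8") as f:
    wav_path, mel_npy_path = f.readline().strip().split("|")

# mels are saved feature-major, i.e. shape (num_mels, T), after the transpose in extract_spectrograms
mel = np.load(mel_npy_path)
print(wav_path, mel.shape)

Running the script with --quantized additionally writes quantized waveforms under quant/, and --debug writes audio reconstructed from the extracted mels under wav_gl/.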