diff --git a/TTS/bin/extract_tts_audio.py b/TTS/bin/extract_tts_audio.py
new file mode 100755
index 00000000..de020fbe
--- /dev/null
+++ b/TTS/bin/extract_tts_audio.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+"""Extract audio from a trained TTS model with teacher forcing."""
+
+import argparse
+import os
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from TTS.config import load_config
+from TTS.tts.datasets import TTSDataset, load_tts_samples
+from TTS.tts.models import setup_model
+from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import count_parameters
+from trainer.generic_utils import to_cuda
+
+use_cuda = torch.cuda.is_available()
+
+
+def set_filename(wav_path, out_path):
+    wav_file = os.path.basename(wav_path)
+    file_name = wav_file.split(".")[0]
+    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
+    os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
+    os.makedirs(os.path.join(out_path, "wav_gt"), exist_ok=True)
+    wavq_path = os.path.join(out_path, "quant", file_name)
+    wav_gt_path = os.path.join(out_path, "wav_gt", file_name + ".wav")
+    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
+    return file_name, wavq_path, wav_gt_path, wav_path
+
+
+def extract_audios(
+    data_loader, model, ap, output_path, quantized_wav=False, save_gt_audio=False, use_cuda=True
+):
+    model.eval()
+    export_metadata = []
+    for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
+        batch = model.format_batch(batch)
+        batch = model.format_batch_on_device(batch)
+
+        if use_cuda:
+            for k, v in batch.items():
+                batch[k] = to_cuda(v)
+
+        tokens = batch["tokens"]
+        token_lengths = batch["token_lens"]
+        spec = batch["spec"]
+        spec_lens = batch["spec_lens"]
+        d_vectors = batch["d_vectors"]
+        speaker_ids = batch["speaker_ids"]
+        language_ids = batch["language_ids"]
+        item_idx = batch["audio_files_path"]
+        wav_lengths = batch["waveform_lens"]
+
+        outputs = model.inference_with_MAS(
+            tokens,
+            spec,
+            spec_lens,
+            aux_input={
+                "x_lengths": token_lengths,
+                "d_vectors": d_vectors,
+                "speaker_ids": speaker_ids,
+                "language_ids": language_ids,
+            },
+        )
+
+        model_output = outputs["model_outputs"]
+        model_output = model_output.detach().cpu().numpy()
+
+        for idx in range(tokens.shape[0]):
+            wav_file_path = item_idx[idx]
+            wav_gt = ap.load_wav(wav_file_path)
+
+            _, wavq_path, wav_gt_path, wav_path = set_filename(wav_file_path, output_path)
+
+            # quantize and save the ground-truth wav
+            if quantized_wav:
+                wavq = ap.quantize(wav_gt)
+                np.save(wavq_path, wavq)
+
+            # save the teacher-forced TTS audio, trimmed to the target length
+            wav = model_output[idx][0]
+            wav_length = wav_lengths[idx]
+            wav = wav[:wav_length]
+            ap.save_wav(wav, wav_path)
+
+            # optionally save the ground-truth audio
+            if save_gt_audio:
+                ap.save_wav(wav_gt, wav_gt_path)
+
+
+def main(args):  # pylint: disable=redefined-outer-name
+    # pylint: disable=global-variable-undefined
+    global meta_data, speaker_manager
+
+    # Audio processor
+    ap = AudioProcessor(**c.audio)
+
+    # load data instances
+    meta_data_train, meta_data_eval = load_tts_samples(
+        c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
+    )
+
+    # use eval and training partitions
+    meta_data = meta_data_train + meta_data_eval
+
+    # setup model
+    model = setup_model(c, meta_data)
+
+    # restore model
+    model.load_checkpoint(c, args.checkpoint_path, eval=True)
+
+    if use_cuda:
+        model.cuda()
+
+    num_params = count_parameters(model)
+    print("\n > Model has {} parameters".format(num_params), flush=True)
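For readers who want to drive the new extraction from Python rather than the CLI, a minimal sketch of how the pieces fit together follows; it only reuses calls that appear in the script above, and "config.json", "checkpoint.pth", and "output/" are placeholder paths, not part of this change.

# Minimal sketch (not part of the patch): mirror what extract_tts_audio.py's __main__ and main() do.
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.utils.audio import AudioProcessor

c = load_config("config.json")  # placeholder config path
c.audio.trim_silence = False    # same overrides the script hard-codes in __main__
c.batch_size = 4

ap = AudioProcessor(**c.audio)
train_samples, eval_samples = load_tts_samples(
    c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
)
samples = train_samples + eval_samples  # use both partitions, as the script does

model = setup_model(c, samples)
model.load_checkpoint(c, "checkpoint.pth", eval=True)  # placeholder checkpoint path

loader = model.get_data_loader(
    config=model.config, assets={}, is_eval=False, samples=samples, verbose=True, num_gpus=1
)
# then: extract_audios(loader, model, ap, "output/", save_gt_audio=True, use_cuda=False)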
+
+    own_loader = model.get_data_loader(
+        config=model.config,
+        assets={},
+        is_eval=False,
+        samples=meta_data,
+        verbose=True,
+        num_gpus=1,
+    )
+
+    extract_audios(
+        own_loader,
+        model,
+        ap,
+        args.output_path,
+        quantized_wav=args.quantized,
+        save_gt_audio=args.save_gt_audio,
+        use_cuda=use_cuda,
+    )
+
+
+if __name__ == "__main__":
+    # python3 TTS/bin/extract_tts_audio.py --config_path /raid/edresson/dev/Checkpoints/YourTTS/new_vctk_trimmed_silence/upsampling/YourTTS_22khz--\>44khz_vocoder_approach_frozen/YourTTS_22khz--\>44khz_vocoder_approach_frozen-April-02-2022_08+23PM-a5f5ebae/config.json --checkpoint_path /raid/edresson/dev/Checkpoints/YourTTS/new_vctk_trimmed_silence/upsampling/YourTTS_22khz--\>44khz_vocoder_approach_frozen/YourTTS_22khz--\>44khz_vocoder_approach_frozen-April-02-2022_08+23PM-a5f5ebae/checkpoint_1600000.pth --output_path ../Test_extract_audio_script/
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
+    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
+    parser.add_argument("--output_path", type=str, help="Path to save the extracted audio files.", required=True)
+    parser.add_argument("--save_gt_audio", default=False, action="store_true", help="Save the ground-truth audio files.")
+    parser.add_argument("--quantized", action="store_true", help="Save quantized audio files.")
+    parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
+    args = parser.parse_args()
+
+    c = load_config(args.config_path)
+    c.audio.trim_silence = False
+    c.batch_size = 4
+    main(args)
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 0a65632a..b2f638f8 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -215,6 +215,7 @@ class VitsDataset(TTSDataset):
             "token_len": len(token_ids),
             "wav": wav,
             "wav_file": wav_filename,
+            "wav_file_path": item["audio_file"],
             "speaker_name": item["speaker_name"],
             "language_name": item["language"],
         }
@@ -280,6 +281,7 @@ class VitsDataset(TTSDataset):
             "speaker_names": batch["speaker_name"],
             "language_names": batch["language_name"],
             "audio_files": batch["wav_file"],
+            "audio_files_path": batch["wav_file_path"],
             "raw_text": batch["raw_text"],
         }
@@ -948,6 +950,70 @@ class Vits(BaseTTS):
             return aux_input["x_lengths"]
         return torch.tensor(x.shape[1:2]).to(x.device)
 
+    def inference_with_MAS(
+        self, x, y, y_lengths, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
+    ):  # pylint: disable=dangerous-default-value
+        """
+        Run inference with teacher forcing: align the ground-truth spectrogram `y` to the text `x`
+        with MAS and decode the waveform from the aligned prior.
+
+        Note:
+            To run in batch mode, provide `x_lengths`; otherwise the model assumes a batch size of 1.
+
+        Shapes:
+            - x: :math:`[B, T_seq]`
+            - x_lengths: :math:`[B]`
+            - d_vectors: :math:`[B, C]`
+            - speaker_ids: :math:`[B]`
+
+        Return Shapes:
+            - model_outputs: :math:`[B, 1, T_wav]`
+            - alignments: :math:`[B, T_seq, T_dec]`
+            - z: :math:`[B, C, T_dec]`
+            - z_p: :math:`[B, C, T_dec]`
+            - m_p: :math:`[B, C, T_dec]`
+            - logs_p: :math:`[B, C, T_dec]`
+        """
+        sid, g, lid = self._set_cond_input(aux_input)
+        x_lengths = self._set_x_lengths(x, aux_input)
+
+        # speaker embedding
+        if self.args.use_speaker_embedding and sid is not None:
+            g = self.emb_g(sid).unsqueeze(-1)
+
+        # language embedding
+        lang_emb = None
+        if self.args.use_language_embedding and lid is not None:
+            lang_emb = self.emb_l(lid).unsqueeze(-1)
+
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+
+        # posterior encoder
+        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
+
+        # flow layers
+        z_p = self.flow(z, y_mask, g=g)
+
+        # get the MAS alignment
+        _, attn = self.forward_mas({}, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
+
+        # expand the prior
+        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
+        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+        # get the predicted aligned distribution
+        z_p = m_p * y_mask
+        # reverse the decoder and predict using the aligned distribution
+        z = self.flow(z_p, y_mask, g=g, reverse=True)
+
+        if self.args.TTS_part_sample_rate and self.args.interpolate_z:
+            z = z.unsqueeze(0)  # pylint: disable=not-callable
+            z = torch.nn.functional.interpolate(z, scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0)
+            y_mask = (
+                sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
+            )  # [B, 1, T_dec_resampled]
+
+        o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
+
+        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
+        return outputs
+
     def inference(
         self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
     ):  # pylint: disable=dangerous-default-value
@@ -1505,7 +1571,8 @@ class Vits(BaseTTS):
         state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
 
         if self.args.TTS_part_sample_rate is not None and eval:
-            # audio resampler is not used in inference time
+            # ignore the audio resampler while loading the checkpoint to avoid errors
+            audio_resampler_aux = self.audio_resampler
             self.audio_resampler = None
 
         # handle fine-tuning from a checkpoint with additional speakers
@@ -1518,11 +1585,15 @@
             state["model"]["emb_g.weight"] = emb_g
         # load the model weights
         self.load_state_dict(state["model"], strict=strict)
+        # restore the audio resampler
+        if self.args.TTS_part_sample_rate is not None:
+            self.audio_resampler = audio_resampler_aux
 
         if eval:
             self.eval()
             assert not self.training
 
+    @staticmethod
     def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
         """Initiate model from config