mirror of https://github.com/coqui-ai/TTS.git
Add debug script
This commit is contained in:
parent a5f5ebae7e
commit bb7a645e7a
@@ -0,0 +1,150 @@
#!/usr/bin/env python3
"""Extract Mel spectrograms with teacher forcing."""

import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
from trainer.generic_utils import to_cuda

use_cuda = torch.cuda.is_available()


def set_filename(wav_path, out_path):
    wav_file = os.path.basename(wav_path)
    file_name = wav_file.split(".")[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gt"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    wav_gt_path = os.path.join(out_path, "wav_gt", file_name + ".wav")
    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
    return file_name, wavq_path, wav_gt_path, wav_path
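
# Illustrative note (not part of the commit): with np.save appending ".npy", the
# directories created above end up holding, per utterance:
#   quant/<file_name>.npy   -- quantized ground-truth wav (only with --quantized)
#   wav/<file_name>.wav     -- teacher-forced model output
#   wav_gt/<file_name>.wav  -- ground-truth audio copy (only with --save_gt_audio)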


def extract_audios(
    data_loader, model, ap, output_path, quantized_wav=False, save_gt_audio=False, use_cuda=True
):
    model.eval()
    export_metadata = []
    for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
        # collate and move the batch exactly as the trainer would
        batch = model.format_batch(batch)
        batch = model.format_batch_on_device(batch)

        if use_cuda:
            for k, v in batch.items():
                batch[k] = to_cuda(v)

        tokens = batch["tokens"]
        token_lengths = batch["token_lens"]
        spec = batch["spec"]
        spec_lens = batch["spec_lens"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        language_ids = batch["language_ids"]
        item_idx = batch["audio_files_path"]
        wav_lengths = batch["waveform_lens"]

        # teacher-forced synthesis: the ground-truth spectrogram drives the MAS alignment
        outputs = model.inference_with_MAS(
            tokens,
            spec,
            spec_lens,
            aux_input={
                "x_lengths": token_lengths,
                "d_vectors": d_vectors,
                "speaker_ids": speaker_ids,
                "language_ids": language_ids,
            },
        )

        model_output = outputs["model_outputs"]
        model_output = model_output.detach().cpu().numpy()

        for idx in range(tokens.shape[0]):
            wav_file_path = item_idx[idx]
            wav_gt = ap.load_wav(wav_file_path)

            _, wavq_path, wav_gt_path, wav_path = set_filename(wav_file_path, output_path)

            # quantize and save the ground-truth wav
            if quantized_wav:
                wavq = ap.quantize(wav_gt)
                np.save(wavq_path, wavq)

            # save the teacher-forced model output, trimmed to the true waveform length
            wav = model_output[idx][0]
            wav_length = wav_lengths[idx]
            wav = wav[:wav_length]
            ap.save_wav(wav, wav_path)

            if save_gt_audio:
                ap.save_wav(wav_gt, wav_gt_path)


def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data, speaker_manager

    # Audio processor
    ap = AudioProcessor(**c.audio)

    # load data instances
    meta_data_train, meta_data_eval = load_tts_samples(
        c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
    )

    # use eval and training partitions
    meta_data = meta_data_train + meta_data_eval

    # setup model
    model = setup_model(c, meta_data)

    # restore model
    model.load_checkpoint(c, args.checkpoint_path, eval=True)

    if use_cuda:
        model.cuda()

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    own_loader = model.get_data_loader(
        config=model.config,
        assets={},
        is_eval=False,
        samples=meta_data,
        verbose=True,
        num_gpus=1,
    )

    extract_audios(
        own_loader,
        model,
        ap,
        args.output_path,
        quantized_wav=args.quantized,
        save_gt_audio=args.save_gt_audio,
        use_cuda=use_cuda,
    )


if __name__ == "__main__":
    # Example:
    # python3 TTS/bin/extract_tts_audio.py --config_path /raid/edresson/dev/Checkpoints/YourTTS/new_vctk_trimmed_silence/upsampling/YourTTS_22khz--\>44khz_vocoder_approach_frozen/YourTTS_22khz--\>44khz_vocoder_approach_frozen-April-02-2022_08+23PM-a5f5ebae/config.json --checkpoint_path /raid/edresson/dev/Checkpoints/YourTTS/new_vctk_trimmed_silence/upsampling/YourTTS_22khz--\>44khz_vocoder_approach_frozen/YourTTS_22khz--\>44khz_vocoder_approach_frozen-April-02-2022_08+23PM-a5f5ebae/checkpoint_1600000.pth --output_path ../Test_extract_audio_script/
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
    parser.add_argument("--output_path", type=str, help="Path to save the extracted audio files.", required=True)
    parser.add_argument("--save_gt_audio", default=False, action="store_true", help="Also save the ground-truth audio files.")
    parser.add_argument("--quantized", action="store_true", help="Save quantized ground-truth audio files.")
    parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
    args = parser.parse_args()

    c = load_config(args.config_path)
    # disable silence trimming and use a small batch size for extraction
    c.audio.trim_silence = False
    c.batch_size = 4
    main(args)
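
# A minimal sketch (not part of the commit) of how the extracted files could be
# inspected afterwards. The directory layout and the ".npy" suffix follow the code
# above; the file name and the use of soundfile are assumptions for illustration.
#
#   import numpy as np
#   import soundfile as sf
#
#   wavq = np.load("../Test_extract_audio_script/quant/p225_001.npy")  # hypothetical utterance
#   wav, sr = sf.read("../Test_extract_audio_script/wav/p225_001.wav")
#   print(wavq.shape, wav.shape, sr)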
@@ -215,6 +215,7 @@ class VitsDataset(TTSDataset):
"token_len": len(token_ids),
|
||||
"wav": wav,
|
||||
"wav_file": wav_filename,
|
||||
"wav_file_path": item["audio_file"],
|
||||
"speaker_name": item["speaker_name"],
|
||||
"language_name": item["language"],
|
||||
}
|
||||
|
@@ -280,6 +281,7 @@ class VitsDataset(TTSDataset):
"speaker_names": batch["speaker_name"],
|
||||
"language_names": batch["language_name"],
|
||||
"audio_files": batch["wav_file"],
|
||||
"audio_files_path": batch["wav_file_path"],
|
||||
"raw_text": batch["raw_text"],
|
||||
}
|
||||
|
||||
|
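# Illustrative note (not part of the diff): the new "wav_file_path" item and the
# "audio_files_path" batch key expose the absolute audio path to downstream code,
# e.g. the per-sample loop in the debug script above:
#
#   item_idx = batch["audio_files_path"]
#   wav_gt = ap.load_wav(item_idx[idx])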
@@ -948,6 +950,70 @@ class Vits(BaseTTS):
            return aux_input["x_lengths"]
        return torch.tensor(x.shape[1:2]).to(x.device)

    def inference_with_MAS(
        self, x, y, y_lengths, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
    ):  # pylint: disable=dangerous-default-value
        """Run teacher-forced inference: align the text prior to the ground-truth spectrogram with MAS and decode the waveform from the aligned prior.

        Note:
            To run in batch mode, provide `x_lengths`; otherwise the model assumes a batch size of 1.

        Shapes:
            - x: :math:`[B, T_seq]`
            - x_lengths: :math:`[B]`
            - d_vectors: :math:`[B, C]`
            - speaker_ids: :math:`[B]`

        Return Shapes:
            - model_outputs: :math:`[B, 1, T_wav]`
            - alignments: :math:`[B, T_seq, T_dec]`
            - z: :math:`[B, C, T_dec]`
            - z_p: :math:`[B, C, T_dec]`
            - m_p: :math:`[B, C, T_dec]`
            - logs_p: :math:`[B, C, T_dec]`
        """
        sid, g, lid = self._set_cond_input(aux_input)
        x_lengths = self._set_x_lengths(x, aux_input)

        # speaker embedding
        if self.args.use_speaker_embedding and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)

        # language embedding
        lang_emb = None
        if self.args.use_language_embedding and lid is not None:
            lang_emb = self.emb_l(lid).unsqueeze(-1)

        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

        # posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

        # flow layers
        z_p = self.flow(z, y_mask, g=g)

        # get the MAS alignment
        _, attn = self.forward_mas({}, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)

        # expand the prior from text resolution [B, C, T_seq] to decoder resolution [B, C, T_dec]
        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
        # take the masked mean of the aligned prior as the predicted latent
        z_p = m_p * y_mask
        # reverse the flow and decode from the aligned distribution
        z = self.flow(z_p, y_mask, g=g, reverse=True)

        if self.args.TTS_part_sample_rate and self.args.interpolate_z:
            # upsample the latent along the time axis to match the vocoder sample rate
            z = z.unsqueeze(0)  # pylint: disable=not-callable
            z = torch.nn.functional.interpolate(z, scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0)
            y_mask = (
                sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
            )  # [B, 1, T_dec_resampled]

        o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)

        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
        return outputs

    def inference(
        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
    ):  # pylint: disable=dangerous-default-value
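
# A minimal sketch (not part of the diff) of how the new `inference_with_MAS` entry
# point is called; the tensor shapes follow the docstring and the debug script above,
# while the `model` and `batch` variables are assumptions for illustration.
#
#   with torch.no_grad():
#       outputs = model.inference_with_MAS(
#           batch["tokens"],        # [B, T_seq]
#           batch["spec"],          # [B, C_spec, T_dec]
#           batch["spec_lens"],     # [B]
#           aux_input={
#               "x_lengths": batch["token_lens"],
#               "d_vectors": batch["d_vectors"],
#               "speaker_ids": batch["speaker_ids"],
#               "language_ids": batch["language_ids"],
#           },
#       )
#   wav = outputs["model_outputs"]  # [B, 1, T_wav]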
@@ -1505,7 +1571,8 @@ class Vits(BaseTTS):
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
|
||||
|
||||
if self.args.TTS_part_sample_rate is not None and eval:
|
||||
# audio resampler is not used in inference time
|
||||
# ignore audio resampler in load checkpoint to avoid errors
|
||||
audio_resampler_aux = self.audio_resampler
|
||||
self.audio_resampler = None
|
||||
|
||||
# handle fine-tuning from a checkpoint with additional speakers
|
||||
|
@@ -1518,11 +1585,15 @@ class Vits(BaseTTS):
state["model"]["emb_g.weight"] = emb_g
|
||||
# load the model weights
|
||||
self.load_state_dict(state["model"], strict=strict)
|
||||
# set the audio_resampler
|
||||
if self.args.TTS_part_sample_rate is not None:
|
||||
self.audio_resampler = audio_resampler_aux
|
||||
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
|
||||
"""Initiate model from config
|
||||
|