mirror of https://github.com/coqui-ai/TTS.git
Remove AP from FastPitchE2e
This commit is contained in:
parent 4556c61902
commit 231c69b12e

@@ -1,48 +1,297 @@
import os
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Tuple, Union
import numpy as np
import pyworld as pw

import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.utils.data import DataLoader
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.layers.losses import ForwardTTSE2ELoss, VitsDiscriminatorLoss
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
from TTS.tts.layers.losses import ForwardTTSE2eLoss, VitsDiscriminatorLoss
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.models.base_tts import BaseTTSE2E
from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
from TTS.tts.models.vits import wav_to_mel
from TTS.tts.utils.helpers import rand_segments, segment
from TTS.tts.models.vits import load_audio, wav_to_mel
from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0, mel_to_wav as mel_to_wav_numpy
from TTS.tts.utils.helpers import rand_segments, segment, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram

def id_to_torch(aux_id, cuda=False):
    if aux_id is not None:
        aux_id = np.asarray(aux_id)
        aux_id = torch.from_numpy(aux_id)
    if cuda:
        return aux_id.cuda()
    return aux_id


def embedding_to_torch(d_vector, cuda=False):
    if d_vector is not None:
        d_vector = np.asarray(d_vector)
        d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
        d_vector = d_vector.squeeze().unsqueeze(0)
    if cuda:
        return d_vector.cuda()
    return d_vector


def numpy_to_torch(np_array, dtype, cuda=False):
    if np_array is None:
        return None
    tensor = torch.as_tensor(np_array, dtype=dtype)
    if cuda:
        return tensor.cuda()
    return tensor


##############################
# DATASET
##############################

class ForwardTTSE2eF0Dataset(F0Dataset):
    """Override F0Dataset to avoid the AudioProcessor."""

    def __init__(
        self,
        audio_config: "AudioConfig",
        samples: Union[List[List], List[Dict]],
        verbose=False,
        cache_path: str = None,
        precompute_num_workers=0,
        normalize_f0=True,
    ):
        self.audio_config = audio_config
        super().__init__(
            samples=samples,
            ap=None,
            verbose=verbose,
            cache_path=cache_path,
            precompute_num_workers=precompute_num_workers,
            normalize_f0=normalize_f0,
        )

    @staticmethod
    def _compute_and_save_pitch(config, wav_file, pitch_file=None):
        wav, _ = load_audio(wav_file)
        f0 = compute_f0(x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax)
        # skip the last F0 value to align with the spectrogram
        if wav.shape[1] % config.hop_length != 0:
            f0 = f0[:-1]
        if pitch_file:
            np.save(pitch_file, f0)
        return f0

    def compute_or_load(self, wav_file):
        """
        compute pitch and return a numpy array of pitch values
        """
        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
        if not os.path.exists(pitch_file):
            pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
        else:
            pitch = np.load(pitch_file)
        return pitch.astype(np.float32)
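
# --- Illustrative sketch (not part of this commit): the trim rule used in
# _compute_and_save_pitch above, assuming compute_f0 yields one more frame than the
# spectrogram whenever the clip length is not an exact multiple of hop_length.
_hop_length = 256
_wav_len = 22050  # hypothetical 1-second clip at 22.05 kHz
_n_f0 = _wav_len // _hop_length + 1  # 87 values under this assumption
if _wav_len % _hop_length != 0:
    _n_f0 -= 1  # 86, matching the mel frame count
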
class ForwardTTSE2eDataset(TTSDataset):
    def __init__(self, *args, **kwargs):
        # don't init the default F0Dataset in TTSDataset
        compute_f0 = kwargs.pop("compute_f0", False)
        kwargs["compute_f0"] = False

        self.audio_config = kwargs["audio_config"]
        del kwargs["audio_config"]

        super().__init__(*args, **kwargs)

        self.compute_f0 = compute_f0
        self.pad_id = self.tokenizer.characters.pad_id
        if self.compute_f0:
            self.f0_dataset = ForwardTTSE2eF0Dataset(
                audio_config=self.audio_config,
                samples=self.samples,
                cache_path=kwargs["f0_cache_path"],
                precompute_num_workers=kwargs["precompute_num_workers"],
            )

    def __getitem__(self, idx):
        item = self.samples[idx]
        raw_text = item["text"]

        wav, _ = load_audio(item["audio_file"])
        wav_filename = os.path.basename(item["audio_file"])

        token_ids = self.get_token_ids(idx, item["text"])

        f0 = None
        if self.compute_f0:
            f0 = self.get_f0(idx)["f0"]

        # after phonemization the text length may change
        # this is a shameful 🤭 hack to prevent longer phonemes
        # TODO: find a better fix
        if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len:
            self.rescue_item_idx += 1
            return self.__getitem__(self.rescue_item_idx)

        return {
            "raw_text": raw_text,
            "token_ids": token_ids,
            "token_len": len(token_ids),
            "wav": wav,
            "pitch": f0,
            "wav_file": wav_filename,
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
        }

    @property
    def lengths(self):
        lens = []
        for item in self.samples:
            _, wav_file, *_ = _parse_sample(item)
            audio_len = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
            lens.append(audio_len)
        return lens
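
    # Note (not part of this commit): in `lengths` above, 16-bit PCM stores 2 bytes per
    # sample, so file_size / 16 * 8 == file_size / 2 approximates the sample count
    # (WAV headers are ignored); the values are used to sort and bucket samples by length.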
    def collate_fn(self, batch):
        """
        Return Shapes:
            - tokens: :math:`[B, T]`
            - token_lens :math:`[B]`
            - token_rel_lens :math:`[B]`
            - pitch :math:`[B, T]`
            - waveform: :math:`[B, 1, T]`
            - waveform_lens: :math:`[B]`
            - waveform_rel_lens: :math:`[B]`
            - speaker_names: :math:`[B]`
            - language_names: :math:`[B]`
            - audiofile_paths: :math:`[B]`
            - raw_texts: :math:`[B]`
        """
        # convert list of dicts to dict of lists
        B = len(batch)
        batch = {k: [dic[k] for dic in batch] for k in batch[0]}

        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True
        )

        max_text_len = max([len(x) for x in batch["token_ids"]])
        token_lens = torch.LongTensor(batch["token_len"])
        token_rel_lens = token_lens / token_lens.max()

        wav_lens = [w.shape[1] for w in batch["wav"]]
        wav_lens = torch.LongTensor(wav_lens)
        wav_lens_max = torch.max(wav_lens)
        wav_rel_lens = wav_lens / wav_lens_max

        pitch_lens = [p.shape[0] for p in batch["pitch"]]
        pitch_lens = torch.LongTensor(pitch_lens)
        pitch_lens_max = torch.max(pitch_lens)

        token_padded = torch.LongTensor(B, max_text_len)
        wav_padded = torch.FloatTensor(B, 1, wav_lens_max)
        pitch_padded = torch.FloatTensor(B, 1, pitch_lens_max)

        token_padded = token_padded.zero_() + self.pad_id
        wav_padded = wav_padded.zero_() + self.pad_id
        pitch_padded = pitch_padded.zero_() + self.pad_id

        for i in range(len(ids_sorted_decreasing)):
            token_ids = batch["token_ids"][i]
            token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids)

            wav = batch["wav"][i]
            wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)

            pitch = batch["pitch"][i]
            pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch)

        return {
            "text_input": token_padded,
            "text_lengths": token_lens,
            "text_rel_lens": token_rel_lens,
            "pitch": pitch_padded,
            "waveform": wav_padded,  # (B x T)
            "waveform_lens": wav_lens,  # (B)
            "waveform_rel_lens": wav_rel_lens,
            "speaker_names": batch["speaker_name"],
            "language_names": batch["language_name"],
            "audio_files": batch["wav_file"],
            "raw_text": batch["raw_text"],
        }
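
# --- Illustrative sketch (not part of this commit): the padding scheme used by
# collate_fn above, on a toy batch of two token sequences with pad_id = 0.
_pad_id = 0
_token_ids = [[5, 3, 9, 2], [7, 1]]
_token_lens = torch.LongTensor([len(t) for t in _token_ids])  # tensor([4, 2])
_token_padded = torch.LongTensor(len(_token_ids), int(_token_lens.max())).zero_() + _pad_id
for _i, _ids in enumerate(_token_ids):
    _token_padded[_i, : len(_ids)] = torch.LongTensor(_ids)
_token_rel_lens = _token_lens / _token_lens.max()  # tensor([1.0000, 0.5000])
# _token_padded -> [[5, 3, 9, 2], [7, 1, 0, 0]]
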
##############################
# CONFIG DEFINITIONS
##############################


@dataclass
class ForwardTTSE2EArgs(ForwardTTSArgs):
class ForwardTTSE2eAudio(Coqpit):
    sample_rate: int = 22050
    hop_length: int = 256
    win_length: int = 1024
    fft_size: int = 1024
    mel_fmin: float = 0.0
    mel_fmax: float = 8000
    num_mels: int = 80
    pitch_fmax: float = 640.0
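
# --- Illustrative sketch (not part of this commit): what the defaults above imply,
# assuming spec_segment_size below counts decoder frames.
_sample_rate, _hop, _spec_segment_size = 22050, 256, 80
_frames_per_second = _sample_rate / _hop  # ~86.1 mel frames per second
_segment_samples = _spec_segment_size * _hop  # 20480 waveform samples (~0.93 s) per training segment
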
@dataclass
class ForwardTTSE2eArgs(ForwardTTSArgs):
    # vocoder_config: BaseGANVocoderConfig = None
    num_chars: int = 100
    encoder_out_channels: int = 80
    spec_segment_size: int = 32
    spec_segment_size: int = 80
    # duration predictor
    detach_duration_predictor: bool = True
    duration_predictor_dropout_p: float = 0.1
    # pitch predictor
    pitch_predictor_dropout_p: float = 0.1
    # discriminator
    init_discriminator: bool = True
    use_spectral_norm_discriminator: bool = False
    # model parameters
    detach_vocoder_input: bool = False
    hidden_channels: int = 192
    hidden_channels: int = 256
    encoder_type: str = "fftransformer"
    encoder_params: dict = field(
        default_factory=lambda: {"hidden_channels_ffn": 768, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
        default_factory=lambda: {
            "hidden_channels_ffn": 1024,
            "num_heads": 2,
            "num_layers": 4,
            "dropout_p": 0.1,
            "kernel_size_fft": 9,
        }
    )
    decoder_type: str = "fftransformer"
    decoder_params: dict = field(
        default_factory=lambda: {"hidden_channels_ffn": 768, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
        default_factory=lambda: {
            "hidden_channels_ffn": 1024,
            "num_heads": 2,
            "num_layers": 4,
            "dropout_p": 0.1,
            "kernel_size_fft": 9,
        }
    )
    # generator
    resblock_type_decoder: str = "1"

@@ -61,35 +310,39 @@ class ForwardTTSE2EArgs(ForwardTTSArgs):
    d_vector_dim: int = 0

class ForwardTTSE2E(BaseTTSE2E):
##############################
# MODEL DEFINITION
##############################


class ForwardTTSE2e(BaseTTSE2E):
    """
    Model training::
        text --> ForwardTTS() --> spec_hat --> rand_seg_select()--> GANVocoder() --> waveform_seg
        spec --------^

    Examples:
        >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2E, ForwardTTSE2EConfig
        >>> config = ForwardTTSE2EConfig()
        >>> model = ForwardTTSE2E(config)
        >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig
        >>> config = ForwardTTSE2eConfig()
        >>> model = ForwardTTSE2e(config)
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        config: Coqpit,
        ap: "AudioProcessor" = None,
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: SpeakerManager = None,
    ):
        super().__init__(config, ap, tokenizer, speaker_manager)
        super().__init__(config=config, tokenizer=tokenizer, speaker_manager=speaker_manager)
        self._set_model_args(config)

        self.init_multispeaker(config)

        self.encoder_model = ForwardTTS(config=self.args, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager)
        self.encoder_model = ForwardTTS(config=self.args, ap=None, tokenizer=tokenizer, speaker_manager=speaker_manager)
        # self.vocoder_model = GAN(config=self.args.vocoder_config, ap=ap)
        self.waveform_decoder = HifiganGenerator(
            self.args.out_channels,
            self.args.hidden_channels,
            1,
            self.args.resblock_type_decoder,
            self.args.resblock_dilation_sizes_decoder,
@@ -179,16 +432,16 @@ class ForwardTTSE2E(BaseTTSE2E):
        encoder_outputs = self.encoder_model(
            x=x, x_lengths=x_lengths, y_lengths=spec_lengths, y=spec, dr=dr, pitch=pitch, aux_input=aux_input
        )
        spec_encoder_output = encoder_outputs["model_outputs"]
        spec_encoder_output_slices, slice_ids = rand_segments(
            x=spec_encoder_output.transpose(1, 2),
        o_en_ex = encoder_outputs["o_en_ex"].transpose(1, 2)  # [B, C_en, T_max2] -> [B, T_max2, C_en]
        o_en_ex_slices, slice_ids = rand_segments(
            x=o_en_ex.transpose(1, 2),
            x_lengths=spec_lengths,
            segment_size=self.args.spec_segment_size,
            let_short_samples=True,
            pad_short=True,
        )
        vocoder_output = self.waveform_decoder(
            x=spec_encoder_output_slices.detach() if self.args.detach_vocoder_input else spec_encoder_output_slices,
            x=o_en_ex_slices.detach() if self.args.detach_vocoder_input else o_en_ex_slices,
            g=encoder_outputs["g"],
        )
        wav_seg = segment(
@@ -205,33 +458,35 @@ class ForwardTTSE2E(BaseTTSE2E):
        return model_outputs

    @torch.no_grad()
    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
        encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input)
        # vocoder_output = self.vocoder_model.model_g(x=encoder_outputs["model_outputs"].transpose(1, 2))
        vocoder_output = self.waveform_decoder(
            x=encoder_outputs["model_outputs"].transpose(1, 2), g=encoder_outputs["g"]
        )
    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
        encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=True)
        o_en_ex = encoder_outputs["o_en_ex"]
        vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["g"])
        model_outputs = {**encoder_outputs}
        model_outputs["encoder_outputs"] = encoder_outputs["model_outputs"]
        model_outputs["model_outputs"] = vocoder_output
        return model_outputs

    @torch.no_grad()
    def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
        encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=False)
        model_outputs = {**encoder_outputs}
        return model_outputs

    @staticmethod
    def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False):
        """Initiate model from config

        Args:
            config (ForwardTTSE2EConfig): Model config.
            config (ForwardTTSE2eConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
        """
        from TTS.utils.audio import AudioProcessor
        from TTS.utils.audio.processor import AudioProcessor

        ap = AudioProcessor.init_from_config(config, verbose=verbose)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config, samples)
        # language_manager = LanguageManager.init_from_config(config)
        return ForwardTTSE2E(config=new_config, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager)
        return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager)

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
@@ -248,7 +503,7 @@ class ForwardTTSE2E(BaseTTSE2E):
        token_lenghts = batch["text_lengths"]
        spec = batch["mel_input"]
        spec_lens = batch["mel_lengths"]
        waveform = batch["waveform"].transpose(1, 2)  # [B, T, C] -> [B, C, T]
        waveform = batch["waveform"]  # [B, T, C] -> [B, C, T]
        pitch = batch["pitch"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
@@ -316,6 +571,8 @@ class ForwardTTSE2E(BaseTTSE2E):
                pitch_output=self.model_outputs_cache["pitch_avg"] if self.args.use_pitch else None,
                pitch_target=self.model_outputs_cache["pitch_avg_gt"] if self.args.use_pitch else None,
                input_lens=batch["text_lengths"],
                waveform=self.model_outputs_cache["waveform_seg"],
                waveform_hat=self.model_outputs_cache["model_outputs"],
                aligner_logprob=self.model_outputs_cache["aligner_logprob"],
                aligner_hard=self.model_outputs_cache["aligner_mas"],
                aligner_soft=self.model_outputs_cache["aligner_soft"],
@@ -340,29 +597,53 @@ class ForwardTTSE2E(BaseTTSE2E):
    def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
        return self.train_step(batch, criterion, optimizer_idx)

    @staticmethod
    def __copy_for_logging(outputs):
        """Change keys and copy values for logging."""
        encoder_outputs = outputs[1].copy()
        encoder_outputs["model_outputs"] = encoder_outputs["encoder_outputs"]
        vocoder_outputs = outputs.copy()
        vocoder_outputs[1]["model_outputs"] = outputs[1]["model_outputs"]
        return encoder_outputs, vocoder_outputs

    def _log(self, batch, outputs, name_prefix="train"):
        figures, audios = {}, {}

    def _log(self, ap, batch, outputs, name_prefix="train"):
        encoder_outputs, vocoder_outputs = self.__copy_for_logging(outputs)
        y_hat = vocoder_outputs[1]["model_outputs"]
        y = vocoder_outputs[1]["waveform_seg"]
        # encoder outputs
        encoder_figures, encoder_audios = self.encoder_model.create_logs(
            batch=batch, outputs=encoder_outputs, ap=self.ap
        )
        model_outputs = outputs[1]["encoder_outputs"]
        alignments = outputs[1]["alignments"]
        mel_input = batch["mel_input"]

        pred_spec = model_outputs[0].data.cpu().numpy()
        gt_spec = mel_input[0].data.cpu().numpy()
        align_img = alignments[0].data.cpu().numpy()

        figures = {
            "prediction": plot_spectrogram(pred_spec, None, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, None, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        # plot pitch figures
        if self.args.use_pitch:
            pitch_avg = abs(outputs[1]["pitch_avg_gt"][0, 0].data.cpu().numpy())
            pitch_avg_hat = abs(outputs[1]["pitch_avg"][0, 0].data.cpu().numpy())
            chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
            pitch_figures = {
                "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
                "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
            }
            figures.update(pitch_figures)

        # plot the attention mask computed from the predicted durations
        if "attn_durations" in outputs[1]:
            alignments_hat = outputs[1]["attn_durations"][0].data.cpu().numpy()
            figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)

        # Sample audio
        encoder_audio = mel_to_wav_numpy(mel=pred_spec.T, mel_basis=self.__mel_basis, **self.config.audio)
        audios[f"{name_prefix}/encoder_audio"] = encoder_audio

        # vocoder outputs
        vocoder_figures = plot_results(y_hat, y, ap, name_prefix)
        y_hat = outputs[1]["model_outputs"]
        y = outputs[1]["waveform_seg"]

        vocoder_figures = plot_results(y_hat=y_hat, y=y, audio_config=self.config.audio, name_prefix=name_prefix)
        figures.update(vocoder_figures)

        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        audios = {f"{name_prefix}/real_audio": sample_voice}
        audios[f"{name_prefix}/encoder_audio"] = encoder_audios["audio"]
        figures = {**encoder_figures, **vocoder_figures}
        audios[f"{name_prefix}/real_audio"] = sample_voice
        return figures, audios

    def train_log(
@@ -374,21 +655,20 @@ class ForwardTTSE2E(BaseTTSE2E):
        be projected onto Tensorboard.

        Args:
            ap (AudioProcessor): audio processor used at training.
            batch (Dict): Model inputs used at the previous training step.
            outputs (Dict): Model outputs generated at the previous training step.

        Returns:
            Tuple[Dict, np.ndarray]: training plots and output waveform.
        """
        figures, audios = self._log(ap=self.ap, batch=batch, outputs=outputs, name_prefix="vocoder/")
        figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/")
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, self.ap.sample_rate)
        logger.train_audios(steps, audios, self.config.audio.sample_rate)

    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
        figures, audios = self._log(ap=self.ap, batch=batch, outputs=outputs, name_prefix="vocoder/")
        figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/")
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)
        logger.eval_audios(steps, audios, self.config.audio.sample_rate)

    def get_aux_input_from_test_sentences(self, sentence_info):
        if hasattr(self.config, "model_args"):
@@ -438,6 +718,78 @@ class ForwardTTSE2E(BaseTTSE2E):
            "language_name": None,
        }

    def synthesize(self, text: str, speaker_id, language_id, d_vector):
        # TODO: add language_id
        is_cuda = next(self.parameters()).is_cuda

        # convert text to sequence of token IDs
        text_inputs = np.asarray(
            self.tokenizer.text_to_ids(text, language=language_id),
            dtype=np.int32,
        )
        # pass tensors to backend
        if speaker_id is not None:
            speaker_id = id_to_torch(speaker_id, cuda=is_cuda)

        if d_vector is not None:
            d_vector = embedding_to_torch(d_vector, cuda=is_cuda)

        # if language_id is not None:
        #     language_id = id_to_torch(language_id, cuda=is_cuda)

        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda)
        text_inputs = text_inputs.unsqueeze(0)

        # synthesize voice
        outputs = self.inference(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id})

        # collect outputs
        wav = outputs["model_outputs"][0].data.cpu().numpy()
        alignments = outputs["alignments"]
        return_dict = {
            "wav": wav,
            "alignments": alignments,
            "text_inputs": text_inputs,
            "outputs": outputs,
        }
        return return_dict
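
    # Note (not part of this commit): synthesize() above renders audio with the trained
    # HiFiGAN waveform decoder, while synthesize_with_gl() below stops at the spectrogram
    # decoder and reconstructs audio from the predicted mel with Griffin-Lim via
    # mel_to_wav_numpy, which is mainly useful for inspecting the encoder/decoder output.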
    def synthesize_with_gl(self, text: str, speaker_id, language_id, d_vector):
        # TODO: add language_id
        is_cuda = next(self.parameters()).is_cuda

        # convert text to sequence of token IDs
        text_inputs = np.asarray(
            self.tokenizer.text_to_ids(text, language=language_id),
            dtype=np.int32,
        )
        # pass tensors to backend
        if speaker_id is not None:
            speaker_id = id_to_torch(speaker_id, cuda=is_cuda)

        if d_vector is not None:
            d_vector = embedding_to_torch(d_vector, cuda=is_cuda)

        # if language_id is not None:
        #     language_id = id_to_torch(language_id, cuda=is_cuda)

        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda)
        text_inputs = text_inputs.unsqueeze(0)

        # synthesize voice
        outputs = self.inference_spec_decoder(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id})

        # collect outputs
        wav = mel_to_wav_numpy(mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio)
        alignments = outputs["alignments"]
        return_dict = {
            "wav": wav[None, :],
            "alignments": alignments,
            "text_inputs": text_inputs,
            "outputs": outputs,
        }
        return return_dict

    @torch.no_grad()
    def test_run(self, assets) -> Tuple[Dict, Dict]:
        """Generic test run for `tts` models used by `Trainer`.
@@ -453,30 +805,147 @@ class ForwardTTSE2E(BaseTTSE2E):
        test_sentences = self.config.test_sentences
        for idx, s_info in enumerate(test_sentences):
            aux_inputs = self.get_aux_input_from_test_sentences(s_info)
            wav, alignment, _, _ = synthesis(
                self,
            outputs = self.synthesize(
                aux_inputs["text"],
                self.config,
                "cuda" in str(next(self.parameters()).device),
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                language_id=aux_inputs["language_id"],
                use_griffin_lim=True,
                do_trim_silence=False,
            ).values()
            test_audios["{}-audio".format(idx)] = wav
            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
            )
            outputs_gl = self.synthesize_with_gl(
                aux_inputs["text"],
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                language_id=aux_inputs["language_id"],
            )
            test_audios["{}-audio".format(idx)] = outputs["wav"].T
            test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T
            test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
        return {"figures": test_figures, "audios": test_audios}

    def test_log(
        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
    ) -> None:
        logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
        logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate)
        logger.test_figures(steps, outputs["figures"])
    def format_batch(self, batch: Dict) -> Dict:
        """Compute speaker, language IDs and d_vector for the batch if necessary."""
        speaker_ids = None
        language_ids = None
        d_vectors = None

        # get numerical speaker ids from speaker names
        if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding:
            speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]]

        if speaker_ids is not None:
            speaker_ids = torch.LongTensor(speaker_ids)
            batch["speaker_ids"] = speaker_ids

        # get d_vectors from audio file names
        if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file:
            d_vector_mapping = self.speaker_manager.d_vectors
            d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]]
            d_vectors = torch.FloatTensor(d_vectors)

        # get language ids from language names
        if (
            self.language_manager is not None
            and self.language_manager.language_id_mapping
            and self.args.use_language_embedding
        ):
            language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]

        if language_ids is not None:
            language_ids = torch.LongTensor(language_ids)

        batch["language_ids"] = language_ids
        batch["d_vectors"] = d_vectors
        batch["speaker_ids"] = speaker_ids
        return batch
    def format_batch_on_device(self, batch):
        """Compute spectrograms on the device."""
        ac = self.config.audio

        # compute spectrograms
        batch["mel_input"] = wav_to_mel(
            batch["waveform"],
            hop_length=ac.hop_length,
            win_length=ac.win_length,
            n_fft=ac.fft_size,
            num_mels=ac.num_mels,
            sample_rate=ac.sample_rate,
            fmin=ac.mel_fmin,
            fmax=ac.mel_fmax,
            center=False,
        )

        assert (
            batch["pitch"].shape[2] == batch["mel_input"].shape[2]
        ), f"{batch['pitch'].shape[2]}, {batch['mel'].shape[2]}"
        batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int()

        # zero the padding frames
        batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1)
        batch["mel_input"] = batch["mel_input"].transpose(1, 2)
        return batch
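
    # Note (not part of this commit): the mel lengths above are recovered from the
    # relative waveform lengths computed in collate_fn. For example, with a padded mel
    # of 200 frames and waveform_rel_lens = [1.0, 0.5, 0.25], the per-item lengths are
    # (200 * rel).int() -> [200, 100, 50]; sequence_mask() then zeroes the padded frames.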
    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None
        else:
            # init dataloader
            dataset = ForwardTTSE2eDataset(
                samples=samples,
                audio_config=self.config.audio,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_text_len=config.min_text_len,
                max_text_len=config.max_text_len,
                min_audio_len=config.min_audio_len,
                max_audio_len=config.max_audio_len,
                phoneme_cache_path=config.phoneme_cache_path,
                precompute_num_workers=config.precompute_num_workers,
                compute_f0=config.compute_f0,
                f0_cache_path=config.f0_cache_path,
                verbose=verbose,
                tokenizer=self.tokenizer,
                start_by_longest=config.start_by_longest,
            )

            # wait all the DDP process to be ready
            if num_gpus > 1:
                dist.barrier()

            # sort input sequences from short to long
            dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(config, dataset, num_gpus)

            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                shuffle=False,  # shuffle is done in the dataset.
                drop_last=False,  # setting this False might cause issues in AMP training.
                sampler=sampler,
                collate_fn=dataset.collate_fn,
                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                pin_memory=False,
            )
        return loader

    def get_criterion(self):
        return [VitsDiscriminatorLoss(self.config), ForwardTTSE2ELoss(self.config)]
        return [VitsDiscriminatorLoss(self.config), ForwardTTSE2eLoss(self.config)]

    def get_optimizer(self) -> List:
        """Initiate and return the GAN optimizers based on the config parameters.
@@ -516,3 +985,12 @@ class ForwardTTSE2E(BaseTTSE2E):
        """Schedule binary loss weight."""
        self.encoder_model.config.binary_loss_warmup_epochs = self.config.binary_loss_warmup_epochs
        self.encoder_model.on_train_step_start(trainer)

    def on_init_start(self, trainer: "Trainer"):
        self.__mel_basis = build_mel_basis(
            sample_rate=self.config.audio.sample_rate,
            fft_size=self.config.audio.fft_size,
            num_mels=self.config.audio.num_mels,
            mel_fmax=self.config.audio.mel_fmax,
            mel_fmin=self.config.audio.mel_fmin,
        )
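
# --- Illustrative sketch (not part of this commit): driving the model after this
# change. The import path mirrors the docstring example above; the checkpoint path,
# sentence, and default config values are hypothetical.
from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig
from TTS.utils.audio.numpy_transforms import build_mel_basis

config = ForwardTTSE2eConfig()
model = ForwardTTSE2e.init_from_config(config)
model.load_checkpoint(config, "best_model.pth", eval=True)  # hypothetical checkpoint
model.on_init_start(trainer=None)  # builds the mel basis used by the Griffin-Lim path

# HiFiGAN head output vs. Griffin-Lim reconstruction from the spectrogram decoder
out = model.synthesize("Hello world.", speaker_id=None, language_id=None, d_vector=None)
out_gl = model.synthesize_with_gl("Hello world.", speaker_id=None, language_id=None, d_vector=None)
print(out["wav"].shape, out_gl["wav"].shape)

# With the default audio values above, the filterbank built in on_init_start should
# have shape (num_mels, fft_size // 2 + 1) = (80, 513).
mel_basis = build_mel_basis(sample_rate=22050, fft_size=1024, num_mels=80, mel_fmin=0.0, mel_fmax=8000)
print(mel_basis.shape)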