# mirror of https://github.com/coqui-ai/TTS.git

import os
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Tuple, Union

import numpy as np
import pyworld as pw
import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.utils.data import DataLoader
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
from TTS.tts.layers.losses import ForwardTTSE2eLoss, VitsDiscriminatorLoss
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.models.base_tts import BaseTTSE2E
from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
from TTS.tts.models.vits import load_audio, wav_to_mel
from TTS.tts.utils.helpers import rand_segments, segment, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results


def id_to_torch(aux_id, cuda=False):
    if aux_id is not None:
        aux_id = np.asarray(aux_id)
        aux_id = torch.from_numpy(aux_id)
    if cuda:
        return aux_id.cuda()
    return aux_id


def embedding_to_torch(d_vector, cuda=False):
    if d_vector is not None:
        d_vector = np.asarray(d_vector)
        d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
        d_vector = d_vector.squeeze().unsqueeze(0)
    if cuda:
        return d_vector.cuda()
    return d_vector


def numpy_to_torch(np_array, dtype, cuda=False):
    if np_array is None:
        return None
    tensor = torch.as_tensor(np_array, dtype=dtype)
    if cuda:
        return tensor.cuda()
    return tensor


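# Minimal usage sketch (editor's note, not in the original file): the helpers above convert raw
# ids / embeddings to tensors before inference. The speaker id (3) and the 256-dim d-vector are
# made-up placeholder values.
#
#   use_cuda = torch.cuda.is_available()
#   speaker_id = id_to_torch(3, cuda=use_cuda)
#   d_vector = embedding_to_torch(np.random.rand(256), cuda=use_cuda)
#   token_ids = numpy_to_torch(np.array([1, 2, 3]), torch.long, cuda=use_cuda)

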
##############################
# DATASET
##############################


class ForwardTTSE2eF0Dataset(F0Dataset):
    """Override F0Dataset to avoid the AudioProcessor."""

    def __init__(
        self,
        audio_config: "AudioConfig",
        samples: Union[List[List], List[Dict]],
        verbose=False,
        cache_path: str = None,
        precompute_num_workers=0,
        normalize_f0=True,
    ):
        self.audio_config = audio_config
        super().__init__(
            samples=samples,
            ap=None,
            verbose=verbose,
            cache_path=cache_path,
            precompute_num_workers=precompute_num_workers,
            normalize_f0=normalize_f0,
        )

    @staticmethod
    def _compute_and_save_pitch(config, wav_file, pitch_file=None):
        wav, _ = load_audio(wav_file)
        f0 = compute_f0(
            x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax
        )
        # skip the last F0 value to align with the spectrogram
        if wav.shape[1] % config.hop_length != 0:
            f0 = f0[:-1]
        if pitch_file:
            np.save(pitch_file, f0)
        return f0

    def compute_or_load(self, wav_file):
        """Compute pitch and return a numpy array of pitch values."""
        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
        if not os.path.exists(pitch_file):
            pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
        else:
            pitch = np.load(pitch_file)
        return pitch.astype(np.float32)


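# Minimal usage sketch (editor's note, not in the original file): precomputing and reading back
# the pitch cache; `audio_config`, `my_samples`, and the cache directory are placeholders.
#
#   f0_dataset = ForwardTTSE2eF0Dataset(
#       audio_config=audio_config,
#       samples=my_samples,
#       cache_path="f0_cache/",
#       precompute_num_workers=4,
#   )
#   pitch = f0_dataset.compute_or_load(my_samples[0]["audio_file"])  # float32 array, one value per mel frame

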
class ForwardTTSE2eDataset(TTSDataset):
    def __init__(self, *args, **kwargs):
        # don't init the default F0Dataset in TTSDataset
        compute_f0 = kwargs.pop("compute_f0", False)
        kwargs["compute_f0"] = False

        self.audio_config = kwargs["audio_config"]
        del kwargs["audio_config"]

        super().__init__(*args, **kwargs)

        self.compute_f0 = compute_f0
        self.pad_id = self.tokenizer.characters.pad_id
        if self.compute_f0:
            self.f0_dataset = ForwardTTSE2eF0Dataset(
                audio_config=self.audio_config,
                samples=self.samples,
                cache_path=kwargs["f0_cache_path"],
                precompute_num_workers=kwargs["precompute_num_workers"],
            )

    def __getitem__(self, idx):
        item = self.samples[idx]
        raw_text = item["text"]

        wav, _ = load_audio(item["audio_file"])
        wav_filename = os.path.basename(item["audio_file"])

        token_ids = self.get_token_ids(idx, item["text"])

        f0 = None
        if self.compute_f0:
            f0 = self.get_f0(idx)["f0"]

        # after phonemization the text length may change
        # this is a shameful 🤭 hack to prevent longer phonemes
        # TODO: find a better fix
        if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len:
            self.rescue_item_idx += 1
            return self.__getitem__(self.rescue_item_idx)

        return {
            "raw_text": raw_text,
            "token_ids": token_ids,
            "token_len": len(token_ids),
            "wav": wav,
            "pitch": f0,
            "wav_file": wav_filename,
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
        }

    @property
    def lengths(self):
        lens = []
        for item in self.samples:
            _, wav_file, *_ = _parse_sample(item)
            audio_len = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
            lens.append(audio_len)
        return lens

    def collate_fn(self, batch):
        """
        Return Shapes:
            - tokens: :math:`[B, T]`
            - token_lens :math:`[B]`
            - token_rel_lens :math:`[B]`
            - pitch :math:`[B, T]`
            - waveform: :math:`[B, 1, T]`
            - waveform_lens: :math:`[B]`
            - waveform_rel_lens: :math:`[B]`
            - speaker_names: :math:`[B]`
            - language_names: :math:`[B]`
            - audiofile_paths: :math:`[B]`
            - raw_texts: :math:`[B]`
        """
        # convert list of dicts to dict of lists
        B = len(batch)
        batch = {k: [dic[k] for dic in batch] for k in batch[0]}

        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True
        )

        max_text_len = max([len(x) for x in batch["token_ids"]])
        token_lens = torch.LongTensor(batch["token_len"])
        token_rel_lens = token_lens / token_lens.max()

        wav_lens = [w.shape[1] for w in batch["wav"]]
        wav_lens = torch.LongTensor(wav_lens)
        wav_lens_max = torch.max(wav_lens)
        wav_rel_lens = wav_lens / wav_lens_max

        pitch_lens = [p.shape[0] for p in batch["pitch"]]
        pitch_lens = torch.LongTensor(pitch_lens)
        pitch_lens_max = torch.max(pitch_lens)

        token_padded = torch.LongTensor(B, max_text_len)
        wav_padded = torch.FloatTensor(B, 1, wav_lens_max)
        pitch_padded = torch.FloatTensor(B, 1, pitch_lens_max)

        token_padded = token_padded.zero_() + self.pad_id
        wav_padded = wav_padded.zero_() + self.pad_id
        pitch_padded = pitch_padded.zero_() + self.pad_id

        for i in range(len(ids_sorted_decreasing)):
            token_ids = batch["token_ids"][i]
            token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids)

            wav = batch["wav"][i]
            wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)

            pitch = batch["pitch"][i]
            pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch)

        return {
            "text_input": token_padded,
            "text_lengths": token_lens,
            "text_rel_lens": token_rel_lens,
            "pitch": pitch_padded,
            "waveform": wav_padded,  # (B x T)
            "waveform_lens": wav_lens,  # (B)
            "waveform_rel_lens": wav_rel_lens,
            "speaker_names": batch["speaker_name"],
            "language_names": batch["language_name"],
            "audio_files": batch["wav_file"],
            "raw_text": batch["raw_text"],
        }


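# Illustrative usage sketch (editor's note, not in the original file): wiring the dataset into a
# PyTorch DataLoader with its own `collate_fn`. `audio_config`, `tokenizer`, and `my_samples` are
# placeholders, and `ForwardTTSE2e.get_data_loader()` below is the path actually used in training.
#
#   dataset = ForwardTTSE2eDataset(
#       samples=my_samples,
#       audio_config=audio_config,
#       tokenizer=tokenizer,
#       compute_f0=True,
#       f0_cache_path="f0_cache/",
#       precompute_num_workers=0,
#   )
#   loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)
#   batch = next(iter(loader))
#   # batch["text_input"]: [B, T_tokens], batch["waveform"]: [B, 1, T_samples], batch["pitch"]: [B, 1, T_frames]

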
##############################
# CONFIG DEFINITIONS
##############################


@dataclass
class ForwardTTSE2eAudio(Coqpit):
    sample_rate: int = 22050
    hop_length: int = 256
    win_length: int = 1024
    fft_size: int = 1024
    mel_fmin: float = 0.0
    mel_fmax: float = 8000
    num_mels: int = 80
    pitch_fmax: float = 640.0


@dataclass
class ForwardTTSE2eArgs(ForwardTTSArgs):
    # vocoder_config: BaseGANVocoderConfig = None
    num_chars: int = 100
    encoder_out_channels: int = 80
    spec_segment_size: int = 80
    # duration predictor
    detach_duration_predictor: bool = True
    duration_predictor_dropout_p: float = 0.1
    # pitch predictor
    pitch_predictor_dropout_p: float = 0.1
    # discriminator
    init_discriminator: bool = True
    use_spectral_norm_discriminator: bool = False
    # model parameters
    detach_vocoder_input: bool = False
    hidden_channels: int = 256
    encoder_type: str = "fftransformer"
    encoder_params: dict = field(
        default_factory=lambda: {
            "hidden_channels_ffn": 1024,
            "num_heads": 2,
            "num_layers": 4,
            "dropout_p": 0.1,
            "kernel_size_fft": 9,
        }
    )
    decoder_type: str = "fftransformer"
    decoder_params: dict = field(
        default_factory=lambda: {
            "hidden_channels_ffn": 1024,
            "num_heads": 2,
            "num_layers": 4,
            "dropout_p": 0.1,
            "kernel_size_fft": 9,
        }
    )
    # generator
    resblock_type_decoder: str = "1"
    resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
    resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
    upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
    upsample_initial_channel_decoder: int = 512
    upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
    # multi-speaker params
    use_speaker_embedding: bool = False
    num_speakers: int = 0
    speakers_file: str = None
    d_vector_file: str = None
    speaker_embedding_channels: int = 384
    use_d_vector_file: bool = False
    d_vector_dim: int = 0


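# Illustrative note (editor's note, not in the original file): the decoder upsample rates multiply
# out to the hop length (8 * 8 * 2 * 2 = 256 == ForwardTTSE2eAudio.hop_length), so each expanded
# encoder frame is upsampled to exactly one hop of audio. A minimal sketch of overriding defaults:
#
#   audio_config = ForwardTTSE2eAudio(sample_rate=22050, num_mels=80)
#   model_args = ForwardTTSE2eArgs(spec_segment_size=64, use_speaker_embedding=True, num_speakers=4)

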
##############################
# MODEL DEFINITION
##############################


class ForwardTTSE2e(BaseTTSE2E):
    """
    Model training::

        text --> ForwardTTS() --> spec_hat --> rand_seg_select() --> GANVocoder() --> waveform_seg
        spec --------^

    Examples:
        >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig
        >>> config = ForwardTTSE2eConfig()
        >>> model = ForwardTTSE2e(config)
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        config: Coqpit,
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: SpeakerManager = None,
    ):
        super().__init__(config=config, tokenizer=tokenizer, speaker_manager=speaker_manager)
        self._set_model_args(config)

        self.init_multispeaker(config)

        self.encoder_model = ForwardTTS(config=self.args, ap=None, tokenizer=tokenizer, speaker_manager=speaker_manager)
        # self.vocoder_model = GAN(config=self.args.vocoder_config, ap=ap)
        self.waveform_decoder = HifiganGenerator(
            self.args.hidden_channels,
            1,
            self.args.resblock_type_decoder,
            self.args.resblock_dilation_sizes_decoder,
            self.args.resblock_kernel_sizes_decoder,
            self.args.upsample_kernel_sizes_decoder,
            self.args.upsample_initial_channel_decoder,
            self.args.upsample_rates_decoder,
            inference_padding=0,
            cond_channels=self.embedded_speaker_dim,
            conv_pre_weight_norm=False,
            conv_post_weight_norm=False,
            conv_post_bias=False,
        )

        # use Vits Discriminator for limiting VRAM use
        if self.args.init_discriminator:
            self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_discriminator)

    def init_multispeaker(self, config: Coqpit):
        """Init for multi-speaker training.

        Args:
            config (Coqpit): Model configuration.
        """
        self.embedded_speaker_dim = 0
        self.num_speakers = self.args.num_speakers
        self.audio_transform = None

        if self.speaker_manager:
            self.num_speakers = self.speaker_manager.num_speakers

        if self.args.use_speaker_embedding:
            self._init_speaker_embedding()

        if self.args.use_d_vector_file:
            self._init_d_vector()

    def _init_speaker_embedding(self):
        # pylint: disable=attribute-defined-outside-init
        if self.num_speakers > 0:
            print(" > initialization of speaker-embedding layers.")
            self.embedded_speaker_dim = self.args.speaker_embedding_channels
            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

    def _init_d_vector(self):
        # pylint: disable=attribute-defined-outside-init
        if hasattr(self, "emb_g"):
            raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
        self.embedded_speaker_dim = self.args.d_vector_dim

    def get_aux_input(self, *args, **kwargs) -> Dict:
        return self.encoder_model.get_aux_input(*args, **kwargs)

    def forward(
        self,
        x: torch.LongTensor,
        x_lengths: torch.LongTensor,
        spec_lengths: torch.LongTensor,
        spec: torch.FloatTensor,
        waveform: torch.FloatTensor,
        dr: torch.IntTensor = None,
        pitch: torch.FloatTensor = None,
        aux_input: Dict = {"d_vectors": None, "speaker_ids": None},  # pylint: disable=unused-argument
    ) -> Dict:
        """Model's forward pass.

        Args:
            x (torch.LongTensor): Input character sequences.
            x_lengths (torch.LongTensor): Input sequence lengths.
            spec_lengths (torch.LongTensor): Spectrogram sequence lengths. Defaults to None.
            spec (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None.
            waveform (torch.FloatTensor): Waveform. Defaults to None.
            dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None.
            pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None.
            aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": None, "speaker_ids": None}`.

        Shapes:
            - x: :math:`[B, T_max]`
            - x_lengths: :math:`[B]`
            - spec_lengths: :math:`[B]`
            - spec: :math:`[B, T_max2]`
            - waveform: :math:`[B, C, T_max2]`
            - dr: :math:`[B, T_max]`
            - g: :math:`[B, C]`
            - pitch: :math:`[B, 1, T]`
        """
        encoder_outputs = self.encoder_model(
            x=x, x_lengths=x_lengths, y_lengths=spec_lengths, y=spec, dr=dr, pitch=pitch, aux_input=aux_input
        )
        o_en_ex = encoder_outputs["o_en_ex"].transpose(1, 2)  # [B, C_en, T_max2] -> [B, T_max2, C_en]
        o_en_ex_slices, slice_ids = rand_segments(
            x=o_en_ex.transpose(1, 2),
            x_lengths=spec_lengths,
            segment_size=self.args.spec_segment_size,
            let_short_samples=True,
            pad_short=True,
        )
        vocoder_output = self.waveform_decoder(
            x=o_en_ex_slices.detach() if self.args.detach_vocoder_input else o_en_ex_slices,
            g=encoder_outputs["g"],
        )
        wav_seg = segment(
            waveform,
            slice_ids * self.config.audio.hop_length,
            self.args.spec_segment_size * self.config.audio.hop_length,
            pad_short=True,
        )
        model_outputs = {**encoder_outputs}
        model_outputs["encoder_outputs"] = encoder_outputs["model_outputs"]
        model_outputs["model_outputs"] = vocoder_output
        model_outputs["waveform_seg"] = wav_seg
        model_outputs["slice_ids"] = slice_ids
        return model_outputs

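    # Illustrative arithmetic (editor's note, not in the original file): with the default config
    # values (spec_segment_size=80, hop_length=256), the random slice fed to the HiFi-GAN decoder
    # covers 80 frames and the matching ground-truth waveform segment covers
    # 80 * 256 = 20480 samples (~0.93 s at 22050 Hz), which is why `slice_ids` are scaled by
    # `hop_length` before calling `segment()` above.
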
    @torch.no_grad()
    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
        encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=True)
        o_en_ex = encoder_outputs["o_en_ex"]
        vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["g"])
        model_outputs = {**encoder_outputs}
        model_outputs["model_outputs"] = vocoder_output
        return model_outputs

    @torch.no_grad()
    def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
        encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=False)
        model_outputs = {**encoder_outputs}
        return model_outputs

    def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
        if optimizer_idx == 0:
            tokens = batch["text_input"]
            token_lengths = batch["text_lengths"]
            spec = batch["mel_input"]
            spec_lens = batch["mel_lengths"]
            waveform = batch["waveform"]  # [B, 1, T]
            pitch = batch["pitch"]
            d_vectors = batch["d_vectors"]
            speaker_ids = batch["speaker_ids"]
            language_ids = batch["language_ids"]

            # generator pass
            outputs = self.forward(
                x=tokens,
                x_lengths=token_lengths,
                spec_lengths=spec_lens,
                spec=spec,
                waveform=waveform,
                pitch=pitch,
                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
            )

            # cache tensors for the generator pass
            self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init

            # compute scores and features
            scores_d_fake, _, scores_d_real, _ = self.disc(outputs["model_outputs"].detach(), outputs["waveform_seg"])

            # compute loss
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    scores_disc_fake=scores_d_fake,
                    scores_disc_real=scores_d_real,
                )
            return outputs, loss_dict

        if optimizer_idx == 1:
            mel = batch["mel_input"].transpose(1, 2)

            # compute melspec segment
            with autocast(enabled=False):
                mel_slice = segment(
                    mel.float(), self.model_outputs_cache["slice_ids"], self.args.spec_segment_size, pad_short=True
                )

                mel_slice_hat = wav_to_mel(
                    y=self.model_outputs_cache["model_outputs"].float(),
                    n_fft=self.config.audio.fft_size,
                    sample_rate=self.config.audio.sample_rate,
                    num_mels=self.config.audio.num_mels,
                    hop_length=self.config.audio.hop_length,
                    win_length=self.config.audio.win_length,
                    fmin=self.config.audio.mel_fmin,
                    fmax=self.config.audio.mel_fmax,
                    center=False,
                )

            # compute discriminator scores and features
            scores_d_fake, feats_d_fake, _, feats_d_real = self.disc(
                self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
            )

            # compute losses
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    decoder_output=self.model_outputs_cache["encoder_outputs"],
                    decoder_target=batch["mel_input"],
                    decoder_output_lens=batch["mel_lengths"],
                    dur_output=self.model_outputs_cache["durations_log"],
                    dur_target=self.model_outputs_cache["aligner_durations"],
                    pitch_output=self.model_outputs_cache["pitch_avg"] if self.args.use_pitch else None,
                    pitch_target=self.model_outputs_cache["pitch_avg_gt"] if self.args.use_pitch else None,
                    input_lens=batch["text_lengths"],
                    waveform=self.model_outputs_cache["waveform_seg"],
                    waveform_hat=self.model_outputs_cache["model_outputs"],
                    aligner_logprob=self.model_outputs_cache["aligner_logprob"],
                    aligner_hard=self.model_outputs_cache["aligner_mas"],
                    aligner_soft=self.model_outputs_cache["aligner_soft"],
                    binary_loss_weight=self.encoder_model.binary_loss_weight,
                    feats_fake=feats_d_fake,
                    feats_real=feats_d_real,
                    scores_fake=scores_d_fake,
                    spec_slice=mel_slice,
                    spec_slice_hat=mel_slice_hat,
                )

            # compute duration error for logging
            durations_pred = self.model_outputs_cache["durations"]
            durations_target = self.model_outputs_cache["aligner_durations"]
            duration_error = torch.abs(durations_target - durations_pred).sum() / batch["text_lengths"].sum()
            loss_dict["duration_error"] = duration_error

            return self.model_outputs_cache, loss_dict

        raise ValueError(" [!] Unexpected `optimizer_idx`.")

    def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
        return self.train_step(batch, criterion, optimizer_idx)

    def _log(self, batch, outputs, name_prefix="train"):
        figures, audios = {}, {}

        # encoder outputs
        model_outputs = outputs[1]["encoder_outputs"]
        alignments = outputs[1]["alignments"]
        mel_input = batch["mel_input"]

        pred_spec = model_outputs[0].data.cpu().numpy()
        gt_spec = mel_input[0].data.cpu().numpy()
        align_img = alignments[0].data.cpu().numpy()

        figures = {
            "prediction": plot_spectrogram(pred_spec, None, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, None, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        # plot pitch figures
        if self.args.use_pitch:
            pitch_avg = abs(outputs[1]["pitch_avg_gt"][0, 0].data.cpu().numpy())
            pitch_avg_hat = abs(outputs[1]["pitch_avg"][0, 0].data.cpu().numpy())
            chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
            pitch_figures = {
                "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
                "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
            }
            figures.update(pitch_figures)

        # plot the attention mask computed from the predicted durations
        if "attn_durations" in outputs[1]:
            alignments_hat = outputs[1]["attn_durations"][0].data.cpu().numpy()
            figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)

        # Sample audio
        encoder_audio = mel_to_wav_numpy(
            mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.__mel_basis, **self.config.audio
        )
        audios[f"{name_prefix}/encoder_audio"] = encoder_audio

        # vocoder outputs
        y_hat = outputs[1]["model_outputs"]
        y = outputs[1]["waveform_seg"]

        vocoder_figures = plot_results(y_hat=y_hat, y=y, audio_config=self.config.audio, name_prefix=name_prefix)
        figures.update(vocoder_figures)

        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        audios[f"{name_prefix}/real_audio"] = sample_voice
        return figures, audios

    def train_log(
        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
    ):  # pylint: disable=no-self-use, unused-argument
        """Create visualizations and waveform examples.

        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
        be projected onto Tensorboard.

        Args:
            batch (Dict): Model inputs used at the previous training step.
            outputs (Dict): Model outputs generated at the previous training step.

        Returns:
            Tuple[Dict, np.ndarray]: training plots and output waveform.
        """
        figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/")
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, self.config.audio.sample_rate)

    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
        figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/")
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.config.audio.sample_rate)

    def get_aux_input_from_test_sentences(self, sentence_info):
        if hasattr(self.config, "model_args"):
            config = self.config.model_args
        else:
            config = self.config

        # extract speaker and language info
        text, speaker_name, style_wav, language_name = None, None, None, None  # pylint: disable=unused-variable

        if isinstance(sentence_info, list):
            if len(sentence_info) == 1:
                text = sentence_info[0]
            elif len(sentence_info) == 2:
                text, speaker_name = sentence_info
            elif len(sentence_info) == 3:
                text, speaker_name, style_wav = sentence_info
            elif len(sentence_info) == 4:
                text, speaker_name, style_wav, language_name = sentence_info
        else:
            text = sentence_info

        # get speaker id/d_vector
        speaker_id, d_vector, language_id = None, None, None  # pylint: disable=unused-variable
        if hasattr(self, "speaker_manager"):
            if config.use_d_vector_file:
                if speaker_name is None:
                    d_vector = self.speaker_manager.get_random_d_vector()
                else:
                    d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=None, randomize=False)
            elif config.use_speaker_embedding:
                if speaker_name is None:
                    speaker_id = self.speaker_manager.get_random_speaker_id()
                else:
                    speaker_id = self.speaker_manager.speaker_ids[speaker_name]

        # get language id
        # if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
        #     language_id = self.language_manager.language_id_mapping[language_name]

        return {
            "text": text,
            "speaker_id": speaker_id,
            "style_wav": style_wav,
            "d_vector": d_vector,
            "language_id": None,
            "language_name": None,
        }

    def synthesize(self, text: str, speaker_id, language_id, d_vector):
        # TODO: add language_id
        is_cuda = next(self.parameters()).is_cuda

        # convert text to sequence of token IDs
        text_inputs = np.asarray(
            self.tokenizer.text_to_ids(text, language=language_id),
            dtype=np.int32,
        )
        # pass tensors to backend
        if speaker_id is not None:
            speaker_id = id_to_torch(speaker_id, cuda=is_cuda)

        if d_vector is not None:
            d_vector = embedding_to_torch(d_vector, cuda=is_cuda)

        # if language_id is not None:
        #     language_id = id_to_torch(language_id, cuda=is_cuda)

        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda)
        text_inputs = text_inputs.unsqueeze(0)

        # synthesize voice
        outputs = self.inference(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id})

        # collect outputs
        wav = outputs["model_outputs"][0].data.cpu().numpy()
        alignments = outputs["alignments"]
        return_dict = {
            "wav": wav,
            "alignments": alignments,
            "text_inputs": text_inputs,
            "outputs": outputs,
        }
        return return_dict

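    # Minimal usage sketch (editor's note, not in the original file): end-to-end synthesis with a
    # single-speaker model; the text is a placeholder and `model` is a trained ForwardTTSE2e instance.
    #
    #   out = model.synthesize("Hello world.", speaker_id=None, language_id=None, d_vector=None)
    #   waveform = out["wav"]  # numpy array produced by the HiFi-GAN decoder
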
    def synthesize_with_gl(self, text: str, speaker_id, language_id, d_vector):
        # TODO: add language_id
        is_cuda = next(self.parameters()).is_cuda

        # convert text to sequence of token IDs
        text_inputs = np.asarray(
            self.tokenizer.text_to_ids(text, language=language_id),
            dtype=np.int32,
        )
        # pass tensors to backend
        if speaker_id is not None:
            speaker_id = id_to_torch(speaker_id, cuda=is_cuda)

        if d_vector is not None:
            d_vector = embedding_to_torch(d_vector, cuda=is_cuda)

        # if language_id is not None:
        #     language_id = id_to_torch(language_id, cuda=is_cuda)

        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda)
        text_inputs = text_inputs.unsqueeze(0)

        # synthesize voice
        outputs = self.inference_spec_decoder(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id})

        # collect outputs
        wav = mel_to_wav_numpy(
            mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio
        )
        alignments = outputs["alignments"]
        return_dict = {
            "wav": wav[None, :],
            "alignments": alignments,
            "text_inputs": text_inputs,
            "outputs": outputs,
        }
        return return_dict

    @torch.no_grad()
    def test_run(self, assets) -> Tuple[Dict, Dict]:
        """Generic test run for `tts` models used by `Trainer`.

        You can override this for a different behaviour.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        for idx, s_info in enumerate(test_sentences):
            aux_inputs = self.get_aux_input_from_test_sentences(s_info)
            outputs = self.synthesize(
                aux_inputs["text"],
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                language_id=aux_inputs["language_id"],
            )
            outputs_gl = self.synthesize_with_gl(
                aux_inputs["text"],
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                language_id=aux_inputs["language_id"],
            )
            test_audios["{}-audio".format(idx)] = outputs["wav"].T
            test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T
            test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
        return {"figures": test_figures, "audios": test_audios}

    def test_log(
        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
    ) -> None:
        logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate)
        logger.test_figures(steps, outputs["figures"])

    def format_batch(self, batch: Dict) -> Dict:
        """Compute speaker and language IDs and d_vectors for the batch if necessary."""
        speaker_ids = None
        language_ids = None
        d_vectors = None

        # get numerical speaker ids from speaker names
        if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding:
            speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]]

        if speaker_ids is not None:
            speaker_ids = torch.LongTensor(speaker_ids)
            batch["speaker_ids"] = speaker_ids

        # get d_vectors from audio file names
        if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file:
            d_vector_mapping = self.speaker_manager.d_vectors
            d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]]
            d_vectors = torch.FloatTensor(d_vectors)

        # get language ids from language names
        if (
            self.language_manager is not None
            and self.language_manager.language_id_mapping
            and self.args.use_language_embedding
        ):
            language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]

        if language_ids is not None:
            language_ids = torch.LongTensor(language_ids)

        batch["language_ids"] = language_ids
        batch["d_vectors"] = d_vectors
        batch["speaker_ids"] = speaker_ids
        return batch

    def format_batch_on_device(self, batch):
        """Compute spectrograms on the device."""
        ac = self.config.audio

        # compute spectrograms
        batch["mel_input"] = wav_to_mel(
            batch["waveform"],
            hop_length=ac.hop_length,
            win_length=ac.win_length,
            n_fft=ac.fft_size,
            num_mels=ac.num_mels,
            sample_rate=ac.sample_rate,
            fmin=ac.mel_fmin,
            fmax=ac.mel_fmax,
            center=False,
        )

        assert (
            batch["pitch"].shape[2] == batch["mel_input"].shape[2]
        ), f"{batch['pitch'].shape[2]}, {batch['mel_input'].shape[2]}"
        batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int()

        # zero the padding frames
        batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1)
        batch["mel_input"] = batch["mel_input"].transpose(1, 2)
        return batch

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None
        else:
            # init dataloader
            dataset = ForwardTTSE2eDataset(
                samples=samples,
                audio_config=self.config.audio,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_text_len=config.min_text_len,
                max_text_len=config.max_text_len,
                min_audio_len=config.min_audio_len,
                max_audio_len=config.max_audio_len,
                phoneme_cache_path=config.phoneme_cache_path,
                precompute_num_workers=config.precompute_num_workers,
                compute_f0=config.compute_f0,
                f0_cache_path=config.f0_cache_path,
                verbose=verbose,
                tokenizer=self.tokenizer,
                start_by_longest=config.start_by_longest,
            )

            # wait for all DDP processes to be ready
            if num_gpus > 1:
                dist.barrier()

            # sort input sequences from short to long
            dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(config, dataset, num_gpus)

            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                shuffle=False,  # shuffle is done in the dataset.
                drop_last=False,  # setting this False might cause issues in AMP training.
                sampler=sampler,
                collate_fn=dataset.collate_fn,
                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                pin_memory=False,
            )
        return loader

    def get_criterion(self):
        return [VitsDiscriminatorLoss(self.config), ForwardTTSE2eLoss(self.config)]

    def get_optimizer(self) -> List:
        """Initiate and return the GAN optimizers based on the config parameters.

        It returns two optimizers in a list. The first one is for the discriminator and the second one is for the generator.

        Returns:
            List: optimizers.
        """
        optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
        gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
        optimizer1 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
        )
        return [optimizer0, optimizer1]

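    # Illustrative check (editor's note, not in the original file): the two parameter groups are
    # disjoint and together cover the whole model; only the `disc.` name prefix decides the split.
    #
    #   disc_names = {k for k, _ in model.named_parameters() if k.startswith("disc.")}
    #   gen_names = {k for k, _ in model.named_parameters() if not k.startswith("disc.")}
    #   assert disc_names.isdisjoint(gen_names)
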
    def get_lr(self) -> List:
        """Set the initial learning rates for each optimizer.

        Returns:
            List: learning rates for each optimizer.
        """
        return [self.config.lr_disc, self.config.lr_gen]

    def get_scheduler(self, optimizer) -> List:
        """Set the schedulers for each optimizer.

        Args:
            optimizer (List[`torch.optim.Optimizer`]): List of optimizers.

        Returns:
            List: Schedulers, one for each optimizer.
        """
        scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0])
        scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1])
        return [scheduler_D, scheduler_G]

    def on_train_step_start(self, trainer):
        """Schedule binary loss weight."""
        self.encoder_model.config.binary_loss_warmup_epochs = self.config.binary_loss_warmup_epochs
        self.encoder_model.on_train_step_start(trainer)

    def on_init_start(self, trainer: "Trainer"):
        self.__mel_basis = build_mel_basis(
            sample_rate=self.config.audio.sample_rate,
            fft_size=self.config.audio.fft_size,
            num_mels=self.config.audio.num_mels,
            mel_fmax=self.config.audio.mel_fmax,
            mel_fmin=self.config.audio.mel_fmin,
        )

    @staticmethod
    def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False):
        """Initiate model from config

        Args:
            config (ForwardTTSE2eConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
        """
        from TTS.utils.audio.processor import AudioProcessor

        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config, samples)
        # language_manager = LanguageManager.init_from_config(config)
        return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager)

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):
        """Load model from a checkpoint created by the 👟 Trainer."""
        # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training

    def get_state_dict(self):
        """Custom state dict of the model with all the necessary components for inference."""
        save_state = {
            "config": self.config.to_dict(),
            "args": self.args.to_dict(),
            "model": self.state_dict(),
        }

        if hasattr(self, "emb_g"):
            save_state["speaker_ids"] = self.speaker_manager.speaker_ids

        if self.args.use_d_vector_file:
            # TODO: implement saving of d_vectors
            ...
        return save_state

    def save(self, config, checkpoint_path):
        """Save model to a file."""
        save_state = self.get_state_dict()
        torch.save(save_state, checkpoint_path)