diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 7dac1bb9..b7766e92 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1,24 +1,31 @@
+import collections
 import math
+import os
 from dataclasses import dataclass, field, replace
 from itertools import chain
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.distributed as dist
 import torchaudio
 from coqpit import Coqpit
+from librosa.filters import mel as librosa_mel_fn
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 
 from TTS.tts.configs.shared_configs import CharactersConfig
+from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
-from TTS.tts.utils.languages import LanguageManager
-from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -27,6 +34,263 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
 
+##############################
+# IO / Feature extraction
+##############################
+
+hann_window = {}
+mel_basis = {}
+
+
+def load_audio(file_path):
+    """Load the audio file normalized in [-1, 1]
+
+    Return Shapes:
+        - x: :math:`[1, T]`
+    """
+    x, sr = torchaudio.load(file_path)
+    assert (x > 1).sum() + (x < -1).sum() == 0
+    return x, sr
+
+
+def _amp_to_db(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def _db_to_amp(x, C=1):
+    return torch.exp(x) / C
+
+
+def amp_to_db(magnitudes):
+    output = _amp_to_db(magnitudes)
+    return output
+
+
+def db_to_amp(magnitudes):
+    output = _db_to_amp(magnitudes)
+    return output
+
+
+def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
+    """
+    Args Shapes:
+        - y : :math:`[B, 1, T]`
+
+    Return Shapes:
+        - spec : :math:`[B,C,T]`
+    """
+    y = y.squeeze(1)
+
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global hann_window
+    dtype_device = str(y.dtype) + "_" + str(y.device)
+    wnsize_dtype_device = str(win_length) + "_" + dtype_device
+    if wnsize_dtype_device not in hann_window:
+        hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        window=hann_window[wnsize_dtype_device],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+    )
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+    return spec
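
A quick shape check for `wav_to_spec`, with STFT sizes picked only for illustration: with `center=False` plus the `(n_fft - hop_length) / 2` reflection padding above, the output has `n_fft // 2 + 1` frequency bins and exactly `T / hop_length` frames. This is the invariant the rest of the diff relies on when mapping spectrogram frames back to waveform samples.

```python
import torch
from TTS.tts.models.vits import wav_to_spec

y = torch.rand(2, 1, 25600)  # [B, 1, T] dummy batch, values in [0, 1)
spec = wav_to_spec(y, n_fft=1024, hop_length=256, win_length=1024, center=False)
print(spec.shape)  # torch.Size([2, 513, 100]) -> C = 1024 // 2 + 1, frames = 25600 / 256
```
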
+
+
+def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
+    """
+    Args Shapes:
+        - spec : :math:`[B,C,T]`
+
+    Return Shapes:
+        - mel : :math:`[B,C,T]`
+    """
+    global mel_basis
+    dtype_device = str(spec.dtype) + "_" + str(spec.device)
+    fmax_dtype_device = str(fmax) + "_" + dtype_device
+    if fmax_dtype_device not in mel_basis:
+        mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax)
+        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
+    mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
+    mel = amp_to_db(mel)
+    return mel
+
+
+def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
+    """
+    Args Shapes:
+        - y : :math:`[B, 1, T]`
+
+    Return Shapes:
+        - mel : :math:`[B,C,T]`
+    """
+    y = y.squeeze(1)
+
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global mel_basis, hann_window
+    dtype_device = str(y.dtype) + "_" + str(y.device)
+    fmax_dtype_device = str(fmax) + "_" + dtype_device
+    wnsize_dtype_device = str(win_length) + "_" + dtype_device
+    if fmax_dtype_device not in mel_basis:
+        mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax)
+        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
+    if wnsize_dtype_device not in hann_window:
+        hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        window=hann_window[wnsize_dtype_device],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+    )
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+    spec = amp_to_db(spec)
+    return spec
+
+
+##############################
+# DATASET
+##############################
+
+
+class VitsDataset(TTSDataset):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pad_id = self.tokenizer.characters.pad_id
+
+    def __getitem__(self, idx):
+        item = self.samples[idx]
+
+        text, wav_file, speaker_name, language_name, _ = _parse_sample(item)
+        raw_text = text
+
+        wav, sr = load_audio(wav_file)
+        wav_filename = os.path.basename(wav_file)
+
+        token_ids = self.get_token_ids(idx, text)
+
+        # after phonemization the text length may change
+        # this is a shameful 🤭 hack to prevent longer phonemes
+        # TODO: find a better fix
+        if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len:
+            self.rescue_item_idx += 1
+            return self.__getitem__(self.rescue_item_idx)
+
+        return {
+            "raw_text": raw_text,
+            "token_ids": token_ids,
+            "token_len": len(token_ids),
+            "wav": wav,
+            "wav_file": wav_filename,
+            "speaker_name": speaker_name,
+            "language_name": language_name,
+        }
+
+    @property
+    def lengths(self):
+        lens = []
+        for item in self.samples:
+            _, wav_file, *_ = _parse_sample(item)
+            audio_len = os.path.getsize(wav_file) / 16 * 8  # assuming 16-bit audio: bytes / 2 = samples
+            lens.append(audio_len)
+        return lens
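
The byte-size heuristic in `lengths` is easy to misread: `getsize / 16 * 8` is just `bytes / 2`, i.e. one sample per two bytes of 16-bit PCM. File headers are ignored, which is acceptable for a relative length used only to sort and bucket samples. A worked check with a hypothetical one-second clip:

```python
# Hypothetical numbers: 1 second of 22050 Hz, 16-bit mono PCM.
file_bytes = 22050 * 2            # 2 bytes per sample -> 44100 bytes (plus ~44 header bytes, ignored)
audio_len = file_bytes / 16 * 8   # = file_bytes / 2 = 22050 "samples"
assert audio_len == 22050
```
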
+
+    def collate_fn(self, batch):
+        """
+        Return Shapes:
+            - tokens: :math:`[B, T]`
+            - token_lens :math:`[B]`
+            - token_rel_lens :math:`[B]`
+            - waveform: :math:`[B, 1, T]`
+            - waveform_lens: :math:`[B]`
+            - waveform_rel_lens: :math:`[B]`
+            - speaker_names: :math:`[B]`
+            - language_names: :math:`[B]`
+            - audiofile_paths: :math:`[B]`
+            - raw_texts: :math:`[B]`
+        """
+        # convert list of dicts to dict of lists
+        B = len(batch)
+        batch = {k: [dic[k] for dic in batch] for k in batch[0]}
+
+        _, ids_sorted_decreasing = torch.sort(
+            torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True
+        )
+
+        max_text_len = max([len(x) for x in batch["token_ids"]])
+        token_lens = torch.LongTensor(batch["token_len"])
+        token_rel_lens = token_lens / token_lens.max()
+
+        wav_lens = [w.shape[1] for w in batch["wav"]]
+        wav_lens = torch.LongTensor(wav_lens)
+        wav_lens_max = torch.max(wav_lens)
+        wav_rel_lens = wav_lens / wav_lens_max
+
+        token_padded = torch.LongTensor(B, max_text_len)
+        wav_padded = torch.FloatTensor(B, 1, wav_lens_max)
+        token_padded = token_padded.zero_() + self.pad_id
+        wav_padded = wav_padded.zero_()  # pad the waveform with silence, not the token pad id
+        for i in range(len(ids_sorted_decreasing)):
+            token_ids = batch["token_ids"][i]
+            token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids)
+
+            wav = batch["wav"][i]
+            wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)
+
+        return {
+            "tokens": token_padded,
+            "token_lens": token_lens,
+            "token_rel_lens": token_rel_lens,
+            "waveform": wav_padded,  # (B x 1 x T)
+            "waveform_lens": wav_lens,  # (B)
+            "waveform_rel_lens": wav_rel_lens,
+            "speaker_names": batch["speaker_name"],
+            "language_names": batch["language_name"],
+            "audio_files": batch["wav_file"],
+            "raw_text": batch["raw_text"],
+        }
+
+
+##############################
+# MODEL DEFINITION
+##############################
+
 
 @dataclass
 class VitsArgs(Coqpit):
@@ -268,38 +532,20 @@
     Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
 
     Examples:
-        Init only model layers.
-
         >>> from TTS.tts.configs.vits_config import VitsConfig
         >>> from TTS.tts.models.vits import Vits
         >>> config = VitsConfig()
         >>> model = Vits(config)
-
-        Fully init a model ready for action. All the class attributes and class members
-        (e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values.
-
-        >>> from TTS.tts.configs.vits_config import VitsConfig
-        >>> from TTS.tts.models.vits import Vits
-        >>> config = VitsConfig()
-        >>> model = Vits.init_from_config(config)
     """
 
-    # pylint: disable=dangerous-default-value
-
     def __init__(
        self,
         config: Coqpit,
         ap: "AudioProcessor" = None,
         tokenizer: "TTSTokenizer" = None,
         speaker_manager: SpeakerManager = None,
         language_manager: LanguageManager = None,
     ):
 
-        super().__init__(config, ap, tokenizer, speaker_manager)
-
-        self.END2END = True
-        self.speaker_manager = speaker_manager
-        self.language_manager = language_manager
+        super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
 
         self.init_multispeaker(config)
         self.init_multilingual(config)
@@ -363,10 +609,6 @@
             language_emb_dim=self.embedded_language_dim,
         )
 
-        upsample_rate = math.prod(self.args.upsample_rates_decoder)
-        assert (
-            upsample_rate == self.config.audio.hop_length
-        ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
             1,
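
The assertion deleted above is not gone; it moves into `init_from_config` at the end of this diff, where it runs before any layers are built. The invariant it guards: the HiFi-GAN decoder upsamples one latent frame back into `hop_length` waveform samples, so the product of its upsample rates must equal the STFT hop. A worked check with the stock VITS values (illustrative, matching the usual defaults):

```python
import math

upsample_rates_decoder = [8, 8, 2, 2]  # common VITS/HiFi-GAN rates, for illustration
hop_length = 256

# each spectrogram frame must expand back into exactly one hop of audio samples
assert math.prod(upsample_rates_decoder) == hop_length  # 8 * 8 * 2 * 2 == 256
```
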
@@ -398,6 +640,7 @@
         """
         self.embedded_speaker_dim = 0
         self.num_speakers = self.args.num_speakers
+        self.audio_transform = None
 
         if self.speaker_manager:
             self.num_speakers = self.speaker_manager.num_speakers
@@ -428,8 +671,11 @@
                 orig_freq=self.audio_config["sample_rate"],
                 new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
             )
-        else:
-            self.audio_transform = None
+            # pylint: disable=W0101,W0105
+            self.audio_transform = torchaudio.transforms.Resample(
+                orig_freq=self.config.audio.sample_rate,
+                new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
+            )
 
     def _init_speaker_embedding(self):
         # pylint: disable=attribute-defined-outside-init
@@ -463,6 +709,35 @@
             self.embedded_language_dim = 0
             self.emb_l = None
 
+    def get_aux_input(self, aux_input: Dict):
+        sid, g, lid = self._set_cond_input(aux_input)
+        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
+
+    def _freeze_layers(self):
+        if self.args.freeze_encoder:
+            for param in self.text_encoder.parameters():
+                param.requires_grad = False
+
+            if hasattr(self, "emb_l"):
+                for param in self.emb_l.parameters():
+                    param.requires_grad = False
+
+        if self.args.freeze_PE:
+            for param in self.posterior_encoder.parameters():
+                param.requires_grad = False
+
+        if self.args.freeze_DP:
+            for param in self.duration_predictor.parameters():
+                param.requires_grad = False
+
+        if self.args.freeze_flow_decoder:
+            for param in self.flow.parameters():
+                param.requires_grad = False
+
+        if self.args.freeze_waveform_decoder:
+            for param in self.waveform_decoder.parameters():
+                param.requires_grad = False
+
     @staticmethod
     def _set_cond_input(aux_input: Dict):
         """Set the speaker conditioning input based on the multi-speaker mode."""
@@ -483,58 +758,6 @@
 
         return sid, g, lid
 
-    def get_aux_input(self, aux_input: Dict):
-        sid, g, lid = self._set_cond_input(aux_input)
-        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
-
-    def get_aux_input_from_test_sentences(self, sentence_info):
-        if hasattr(self.config, "model_args"):
-            config = self.config.model_args
-        else:
-            config = self.config
-
-        # extract speaker and language info
-        text, speaker_name, style_wav, language_name = None, None, None, None
-
-        if isinstance(sentence_info, list):
-            if len(sentence_info) == 1:
-                text = sentence_info[0]
-            elif len(sentence_info) == 2:
-                text, speaker_name = sentence_info
-            elif len(sentence_info) == 3:
-                text, speaker_name, style_wav = sentence_info
-            elif len(sentence_info) == 4:
-                text, speaker_name, style_wav, language_name = sentence_info
-        else:
-            text = sentence_info
-
-        # get speaker id/d_vector
-        speaker_id, d_vector, language_id = None, None, None
-        if hasattr(self, "speaker_manager"):
-            if config.use_d_vector_file:
-                if speaker_name is None:
-                    d_vector = self.speaker_manager.get_random_d_vector()
-                else:
-                    d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False)
-            elif config.use_speaker_embedding:
-                if speaker_name is None:
-                    speaker_id = self.speaker_manager.get_random_speaker_id()
-                else:
-                    speaker_id = self.speaker_manager.speaker_ids[speaker_name]
-
-        # get language id
-        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
-            language_id = self.language_manager.language_id_mapping[language_name]
-
-        return {
-            "text": text,
-            "speaker_id": speaker_id,
-            "style_wav": style_wav,
-            "d_vector": d_vector,
-            "language_id": language_id,
-            "language_name": language_name,
-        }
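
`_freeze_layers` (moved up next to the other init helpers above) is driven entirely by `VitsArgs` flags. A minimal sketch of how a fine-tuning setup might use it; the flag names exist in `VitsArgs`, the rest is illustrative:

```python
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs

# e.g. adapt only the duration predictor and waveform decoder to new data
config = VitsConfig(model_args=VitsArgs(freeze_encoder=True, freeze_PE=True, freeze_flow_decoder=True))
model = Vits(config)
model._freeze_layers()  # apply the freeze flags
assert not next(model.text_encoder.parameters()).requires_grad
```
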
 
     def _set_speaker_input(self, aux_input: Dict):
         d_vectors = aux_input.get("d_vectors", None)
         speaker_ids = aux_input.get("speaker_ids", None)
@@ -611,7 +834,7 @@
             - x_lengths: :math:`[B]`
             - y: :math:`[B, C, T_spec]`
             - y_lengths: :math:`[B]`
-            - waveform: :math:`[B, T_wav, 1]`
+            - waveform: :math:`[B, 1, T_wav]`
             - d_vectors: :math:`[B, C, 1]`
             - speaker_ids: :math:`[B]`
             - language_ids: :math:`[B]`
@@ -656,13 +879,14 @@
             logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
 
         # select a random feature segment for the waveform decoder
-        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
+        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
         o = self.waveform_decoder(z_slice, g=g)
 
         wav_seg = segment(
             waveform,
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
+            pad_short=True,
         )
 
         if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
@@ -694,6 +918,7 @@
                 "waveform_seg": wav_seg,
                 "gt_spk_emb": gt_spk_emb,
                 "syn_spk_emb": syn_spk_emb,
+                "slice_ids": slice_ids,
             }
         )
         return outputs
@@ -798,30 +1023,6 @@
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
         return o_hat, y_mask, (z, z_p, z_hat)
 
-    def _freeze_layers(self):
-        if self.args.freeze_encoder:
-            for param in self.text_encoder.parameters():
-                param.requires_grad = False
-
-            if hasattr(self, "emb_l"):
-                for param in self.emb_l.parameters():
-                    param.requires_grad = False
-
-        if self.args.freeze_PE:
-            for param in self.posterior_encoder.parameters():
-                param.requires_grad = False
-
-        if self.args.freeze_DP:
-            for param in self.duration_predictor.parameters():
-                param.requires_grad = False
-
-        if self.args.freeze_flow_decoder:
-            for param in self.flow.parameters():
-                param.requires_grad = False
-
-        if self.args.freeze_waveform_decoder:
-            for param in self.waveform_decoder.parameters():
-                param.requires_grad = False
-
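
Everything from here on assumes the new optimizer ordering this PR introduces: index 0 is the discriminator, index 1 is the generator, and the generator forward pass now runs inside the discriminator step so its outputs can be cached for both losses. A rough sketch of the resulting loop; this is hypothetical driver code (the real loop lives in the Trainer), assuming `model` is an initialized `Vits` and `loader` comes from `get_data_loader`:

```python
optimizers = model.get_optimizer()  # [discriminator optimizer, generator optimizer] after this PR
criteria = model.get_criterion()    # [VitsDiscriminatorLoss, VitsGeneratorLoss]

for batch in loader:
    batch = model.format_batch(batch)
    batch = model.format_batch_on_device(batch)
    for idx in (0, 1):  # 0: D step (also runs G's forward and caches it), 1: G step (re-uses the cache)
        _, loss_dict = model.train_step(batch, criteria, idx)
        optimizers[idx].zero_grad()
        loss_dict["loss"].backward()
        optimizers[idx].step()
```
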
 
     def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Perform a single training step. Run the model forward pass and compute losses.
@@ -835,91 +1036,101 @@
             Tuple[Dict, Dict]: Model ouputs and computed losses.
         """
-        # pylint: disable=attribute-defined-outside-init
-        if optimizer_idx not in [0, 1]:
-            raise ValueError(" [!] Unexpected `optimizer_idx`.")
-
-        self._freeze_layers()
+        mel_lens = batch["mel_lens"]
+
+        if optimizer_idx == 0:
-            text_input = batch["text_input"]
-            text_lengths = batch["text_lengths"]
-            mel_lengths = batch["mel_lengths"]
-            linear_input = batch["linear_input"]
+            tokens = batch["tokens"]
+            token_lengths = batch["token_lens"]
+            spec = batch["spec"]
+            spec_lens = batch["spec_lens"]
+
+            d_vectors = batch["d_vectors"]
             speaker_ids = batch["speaker_ids"]
             language_ids = batch["language_ids"]
             waveform = batch["waveform"]
 
-            # if (waveform > 1).sum() > 0 or (waveform < -1).sum() > 0:
-            #     breakpoint()
-
             # generator pass
             outputs = self.forward(
-                text_input,
-                text_lengths,
-                linear_input.transpose(1, 2),
-                mel_lengths,
-                waveform.transpose(1, 2),
+                tokens,
+                token_lengths,
+                spec,
+                spec_lens,
+                waveform,
                 aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
             )
 
-            # cache tensors for the discriminator
-            self.y_disc_cache = outputs["model_outputs"]
-            self.wav_seg_disc_cache = outputs["waveform_seg"]
-
-            # compute discriminator scores and features
-            outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
-                outputs["model_outputs"], outputs["waveform_seg"]
-            )
-
-            # compute losses
-            with autocast(enabled=False):  # use float32 for the criterion
-                loss_dict = criterion[optimizer_idx](
-                    waveform_hat=outputs["model_outputs"].float(),
-                    waveform=outputs["waveform_seg"].float(),
-                    z_p=outputs["z_p"].float(),
-                    logs_q=outputs["logs_q"].float(),
-                    m_p=outputs["m_p"].float(),
-                    logs_p=outputs["logs_p"].float(),
-                    z_len=mel_lengths,
-                    scores_disc_fake=outputs["scores_disc_fake"],
-                    feats_disc_fake=outputs["feats_disc_fake"],
-                    feats_disc_real=outputs["feats_disc_real"],
-                    loss_duration=outputs["loss_duration"],
-                    use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
-                    gt_spk_emb=outputs["gt_spk_emb"],
-                    syn_spk_emb=outputs["syn_spk_emb"],
-                )
-
-            # if loss_dict["loss_feat"].isnan().sum() > 0 or loss_dict["loss_feat"].isinf().sum() > 0:
-            #     breakpoint()
-
-        elif optimizer_idx == 1:
-            # discriminator pass
-            outputs = {}
+            # cache tensors for the generator pass
+            self.model_outputs_cache = outputs
 
             # compute scores and features
-            outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(
-                self.y_disc_cache.detach(), self.wav_seg_disc_cache
+            scores_disc_fake, _, scores_disc_real, _ = self.disc(
+                outputs["model_outputs"].detach(), outputs["waveform_seg"]
             )
 
             # compute loss
             with autocast(enabled=False):  # use float32 for the criterion
                 loss_dict = criterion[optimizer_idx](
-                    outputs["scores_disc_real"],
-                    outputs["scores_disc_fake"],
+                    scores_disc_real,
+                    scores_disc_fake,
                 )
-        return outputs, loss_dict
+            return {}, loss_dict
+
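
The generator step below compares a mel slice cut from the ground-truth mel against a mel recomputed from the generated waveform segment. The alignment only works because of the frame/sample bookkeeping set up in `forward`: `slice_ids` index spectrogram frames, and the matching waveform segment starts at `slice_ids * hop_length`. In numbers (hop and segment size here are illustrative):

```python
# Illustrative alignment check: frame indices vs. sample indices.
hop_length = 256
spec_segment_size = 32  # frames fed to the waveform decoder

slice_frame_start = 87                                 # a value rand_segments might pick
slice_sample_start = slice_frame_start * hop_length    # 22272: the same instant, in samples
segment_samples = spec_segment_size * hop_length       # 8192-sample ground-truth segment
```
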
+        if optimizer_idx == 1:
+            mel = batch["mel"]
+
+            # compute melspec segment
+            with autocast(enabled=False):
+                mel_slice = segment(
+                    mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True
+                )
+                mel_slice_hat = wav_to_mel(
+                    y=self.model_outputs_cache["model_outputs"].float(),
+                    n_fft=self.config.audio.fft_size,
+                    sample_rate=self.config.audio.sample_rate,
+                    num_mels=self.config.audio.num_mels,
+                    hop_length=self.config.audio.hop_length,
+                    win_length=self.config.audio.win_length,
+                    fmin=self.config.audio.mel_fmin,
+                    fmax=self.config.audio.mel_fmax,
+                    center=False,
+                )
+
+            # compute discriminator scores and features
+            scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
+                self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
+            )
+
+            # compute losses
+            with autocast(enabled=False):  # use float32 for the criterion
+                loss_dict = criterion[optimizer_idx](
+                    mel_slice=mel_slice.float(),
+                    mel_slice_hat=mel_slice_hat.float(),
+                    z_p=self.model_outputs_cache["z_p"].float(),
+                    logs_q=self.model_outputs_cache["logs_q"].float(),
+                    m_p=self.model_outputs_cache["m_p"].float(),
+                    logs_p=self.model_outputs_cache["logs_p"].float(),
+                    z_len=mel_lens,
+                    scores_disc_fake=scores_disc_fake,
+                    feats_disc_fake=feats_disc_fake,
+                    feats_disc_real=feats_disc_real,
+                    loss_duration=self.model_outputs_cache["loss_duration"],
+                    use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
+                    gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
+                    syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
+                )
+
+            return self.model_outputs_cache, loss_dict
+
+        raise ValueError(" [!] Unexpected `optimizer_idx`.")
 
     def _log(self, ap, batch, outputs, name_prefix="train"):  # pylint: disable=unused-argument,no-self-use
-        y_hat = outputs[0]["model_outputs"]
-        y = outputs[0]["waveform_seg"]
+        y_hat = outputs[1]["model_outputs"]
+        y = outputs[1]["waveform_seg"]
         figures = plot_results(y_hat, y, ap, name_prefix)
         sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
         audios = {f"{name_prefix}/audio": sample_voice}
 
-        alignments = outputs[0]["alignments"]
+        alignments = outputs[1]["alignments"]
         align_img = alignments[0].data.cpu().numpy().T
 
         figures.update(
@@ -927,7 +1138,6 @@
                 "alignment": plot_alignment(align_img, output_fig=False),
             }
         )
-
         return figures, audios
 
     def train_log(
@@ -948,7 +1158,7 @@
         """
         figures, audios = self._log(self.ap, batch, outputs, "train")
         logger.train_figures(steps, figures)
-        logger.train_figures(steps, audios, self.ap.sample_rate)
+        logger.train_audios(steps, audios, self.ap.sample_rate)
 
     @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
@@ -959,6 +1169,54 @@
         logger.eval_figures(steps, figures)
         logger.eval_audios(steps, audios, self.ap.sample_rate)
 
+    def get_aux_input_from_test_sentences(self, sentence_info):
+        if hasattr(self.config, "model_args"):
+            config = self.config.model_args
+        else:
+            config = self.config
+
+        # extract speaker and language info
+        text, speaker_name, style_wav, language_name = None, None, None, None
+
+        if isinstance(sentence_info, list):
+            if len(sentence_info) == 1:
+                text = sentence_info[0]
+            elif len(sentence_info) == 2:
+                text, speaker_name = sentence_info
+            elif len(sentence_info) == 3:
+                text, speaker_name, style_wav = sentence_info
+            elif len(sentence_info) == 4:
+                text, speaker_name, style_wav, language_name = sentence_info
+        else:
+            text = sentence_info
+
+        # get speaker id/d_vector
+        speaker_id, d_vector, language_id = None, None, None
+        if hasattr(self, "speaker_manager"):
+            if config.use_d_vector_file:
+                if speaker_name is None:
+                    d_vector = self.speaker_manager.get_random_d_vector()
+                else:
+                    d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False)
+            elif config.use_speaker_embedding:
+                if speaker_name is None:
+                    speaker_id = self.speaker_manager.get_random_speaker_id()
+                else:
+                    speaker_id = self.speaker_manager.speaker_ids[speaker_name]
+
+        # get language id
+        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+            language_id = self.language_manager.language_id_mapping[language_name]
+
+        return {
+            "text": text,
+            "speaker_id": speaker_id,
+            "style_wav": style_wav,
+            "d_vector": d_vector,
+            "language_id": language_id,
+            "language_name": language_name,
+        }
+
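
`get_aux_input_from_test_sentences` accepts each `config.test_sentences` entry as either a bare string or a 1-to-4 element list; anything omitted falls back to a random speaker, no style wav, and no language. For example (speaker and file names are placeholders):

```python
config.test_sentences = [
    "This is a plain sentence.",                      # text only
    ["Another one.", "speaker_0"],                    # + speaker name
    ["With a style clip.", "speaker_0", "ref.wav"],   # + style wav
    ["Multilingual too.", "speaker_0", None, "en"],   # + language name
]
```
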
"style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + "language_name": language_name, + } + @torch.no_grad() def test_run(self, assets) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. @@ -973,56 +1231,187 @@ class Vits(BaseTTS): test_figures = {} test_sentences = self.config.test_sentences for idx, s_info in enumerate(test_sentences): - try: - aux_inputs = self.get_aux_input_from_test_sentences(s_info) - wav, alignment, _, _ = synthesis( - self, - aux_inputs["text"], - self.config, - "cuda" in str(next(self.parameters()).device), - speaker_id=aux_inputs["speaker_id"], - d_vector=aux_inputs["d_vector"], - style_wav=aux_inputs["style_wav"], - language_id=aux_inputs["language_id"], - use_griffin_lim=True, - do_trim_silence=False, - ).values() - test_audios["{}-audio".format(idx)] = wav - test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) - except: # pylint: disable=bare-except - print(" !! Error creating Test Sentence -", idx) + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + wav, alignment, _, _ = synthesis( + self, + aux_inputs["text"], + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + language_id=aux_inputs["language_id"], + use_griffin_lim=True, + do_trim_silence=False, + ).values() + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) return {"figures": test_figures, "audios": test_audios} def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) logger.test_figures(steps, outputs["figures"]) + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + language_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + batch["speaker_ids"] = speaker_ids + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.d_vectors + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]] + d_vectors = torch.FloatTensor(d_vectors) + + # get language ids from language names + if self.language_manager is not None and self.language_manager.language_id_mapping and self.args.use_language_embedding: + language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]] + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + batch["language_ids"] = language_ids + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + ac = self.config.audio + + # compute spectrograms + batch["spec"] = wav_to_spec( + batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False + ) + batch["mel"] = spec_to_mel( + spec = batch["spec"], + n_fft = ac.fft_size, + num_mels = ac.num_mels, + sample_rate = 
+
+    def get_data_loader(
+        self,
+        config: Coqpit,
+        assets: Dict,
+        is_eval: bool,
+        samples: Union[List[Dict], List[List]],
+        verbose: bool,
+        num_gpus: int,
+        rank: int = None,
+    ) -> "DataLoader":
+        if is_eval and not config.run_eval:
+            loader = None
+        else:
+            # setup multi-speaker attributes
+            speaker_id_mapping = None
+            d_vector_mapping = None
+            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
+                if hasattr(config, "model_args"):
+                    speaker_id_mapping = (
+                        self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                    )
+                    d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
+                    config.use_d_vector_file = config.model_args.use_d_vector_file
+                else:
+                    speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
+                    d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
+
+            # setup multi-lingual attributes
+            language_id_mapping = None
+            if hasattr(self, "language_manager"):
+                language_id_mapping = (
+                    self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+                )
+
+            # init dataloader
+            dataset = VitsDataset(
+                samples=samples,
+                # batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
+                min_text_len=config.min_text_len,
+                max_text_len=config.max_text_len,
+                min_audio_len=config.min_audio_len,
+                max_audio_len=config.max_audio_len,
+                phoneme_cache_path=config.phoneme_cache_path,
+                precompute_num_workers=config.precompute_num_workers,
+                verbose=verbose,
+                tokenizer=self.tokenizer,
+                start_by_longest=config.start_by_longest,
+            )
+
+            # wait all the DDP process to be ready
+            if num_gpus > 1:
+                dist.barrier()
+
+            # sort input sequences from short to long
+            dataset.preprocess_samples()
+
+            # sampler for DDP
+            sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+
+            # Weighted samplers
+            # TODO: make this DDP amenable
+            assert not (
+                num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
+            ), "language_weighted_sampler is not supported with DistributedSampler"
+            assert not (
+                num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
+            ), "speaker_weighted_sampler is not supported with DistributedSampler"
+
+            if sampler is None:
+                if getattr(config, "use_language_weighted_sampler", False):
+                    print(" > Using Language weighted sampler")
+                    sampler = get_language_weighted_sampler(dataset.samples)
+                elif getattr(config, "use_speaker_weighted_sampler", False):
+                    print(" > Using Speaker weighted sampler")
+                    sampler = get_speaker_weighted_sampler(dataset.samples)
+
+            loader = DataLoader(
+                dataset,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                shuffle=False,  # shuffle is done in the dataset.
+                drop_last=False,  # setting this False might cause issues in AMP training.
+                sampler=sampler,
+                collate_fn=dataset.collate_fn,
+                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                pin_memory=False,
+            )
+        return loader
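
`get_language_weighted_sampler` is imported at the top of the diff but its body is not shown here; the usual construction, and my assumption about what it does, is a `WeightedRandomSampler` whose per-item weight is the inverse frequency of that item's language, so under-represented languages are drawn more often. A sketch under that assumption (each sample dict is assumed to carry a "language" key):

```python
import collections
import torch
from torch.utils.data.sampler import WeightedRandomSampler

def language_weighted_sampler(samples):
    # weight each item by 1 / (number of items sharing its language)
    counts = collections.Counter(s["language"] for s in samples)
    weights = torch.tensor([1.0 / counts[s["language"]] for s in samples], dtype=torch.double)
    return WeightedRandomSampler(weights, len(weights))
```
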
 
     def get_optimizer(self) -> List:
         """Initiate and return the GAN optimizers based on the config parameters.
 
-        It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
-
         Returns:
             List: optimizers.
         """
-        gen_parameters = chain(
-            self.text_encoder.parameters(),
-            self.posterior_encoder.parameters(),
-            self.flow.parameters(),
-            self.duration_predictor.parameters(),
-            self.waveform_decoder.parameters(),
-        )
-        # add the speaker embedding layer
-        if hasattr(self, "emb_g") and self.args.use_speaker_embedding and not self.args.use_d_vector_file:
-            gen_parameters = chain(gen_parameters, self.emb_g.parameters())
-        # add the language embedding layer
-        if hasattr(self, "emb_l") and self.args.use_language_embedding:
-            gen_parameters = chain(gen_parameters, self.emb_l.parameters())
+        optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
 
-        optimizer0 = get_optimizer(
+        # select generator parameters
+        gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
+        optimizer1 = get_optimizer(
             self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
         )
-        optimizer1 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
         return [optimizer0, optimizer1]
 
     def get_lr(self) -> List:
@@ -1031,7 +1420,7 @@
         Returns:
             List: learning rates for each optimizer.
         """
-        return [self.config.lr_gen, self.config.lr_disc]
+        return [self.config.lr_disc, self.config.lr_gen]
 
     def get_scheduler(self, optimizer) -> List:
         """Set the schedulers for each optimizer.
@@ -1042,9 +1431,9 @@
         Returns:
             List: Schedulers, one for each optimizer.
         """
-        scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
-        scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
-        return [scheduler0, scheduler1]
+        scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0])
+        scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1])
+        return [scheduler_D, scheduler_G]
 
     def get_criterion(self):
         """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
@@ -1054,10 +1443,14 @@
             VitsGeneratorLoss,
         )
 
-        return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
+        return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)]
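
With the reordering above, index 0 consistently means "discriminator" and index 1 "generator" across `get_optimizer`, `get_lr`, `get_scheduler`, and `get_criterion`. Assuming `model` is an initialized `Vits`, the pairing looks like this (comments only, for orientation):

```python
optimizers = model.get_optimizer()  # [D optimizer, G optimizer]
lrs = model.get_lr()                # [lr_disc, lr_gen]
criteria = model.get_criterion()    # [VitsDiscriminatorLoss, VitsGeneratorLoss]
# train_step(batch, criteria, optimizer_idx=0) -> discriminator update
# train_step(batch, criteria, optimizer_idx=1) -> generator update
```
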
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} # handle fine-tuning from a checkpoint with additional speakers - if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: - num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] + if hasattr(self, "emb_g") and state["model"]["vits.emb_g.weight"].shape != self.emb_g.weight.shape: + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["vits.emb_g.weight"].shape[0] print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") - emb_g = state["model"]["emb_g.weight"] + emb_g = state["model"]["vits.emb_g.weight"] new_row = torch.randn(num_new_speakers, emb_g.shape[1]) emb_g = torch.cat([emb_g, new_row], axis=0) - state["model"]["emb_g.weight"] = emb_g + state["model"]["vits.emb_g.weight"] = emb_g + # load the model weights + self.load_state_dict(state["model"], strict=strict) - self.load_state_dict(state["model"], strict=False) if eval: self.eval() assert not self.training @@ -1090,12 +1484,21 @@ class Vits(BaseTTS): """ from TTS.utils.audio import AudioProcessor + upsample_rate = math.prod(config.model_args.upsample_rates_decoder) + assert ( + upsample_rate == config.audio.hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" + ap = AudioProcessor.init_from_config(config, verbose=verbose) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) language_manager = LanguageManager.init_from_config(config) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) +################################## +# VITS CHARACTERS +################################## + class VitsCharacters(BaseCharacters): """Characters class for VITs model for compatibility with pre-trained models"""