From a5f5ebae7ef60ec9f99adfd447b189e0b6eae77b Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 28 Mar 2022 22:04:49 +0000 Subject: [PATCH] Fix lint checks --- TTS/tts/models/vits.py | 39 +- TTS/tts/models/vits_old_way.py | 1590 ----------------- ...est_vits_multilingual_speaker_emb_train.py | 7 +- 3 files changed, 14 insertions(+), 1622 deletions(-) delete mode 100644 TTS/tts/models/vits_old_way.py diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index e09c3667..0a65632a 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -632,7 +632,9 @@ class Vits(BaseTTS): if self.args.TTS_part_sample_rate: self.interpolate_factor = self.config.audio["sample_rate"] / self.args.TTS_part_sample_rate - self.audio_resampler = torchaudio.transforms.Resample(orig_freq=self.config.audio["sample_rate"], new_freq=self.args.TTS_part_sample_rate) + self.audio_resampler = torchaudio.transforms.Resample( + orig_freq=self.config.audio["sample_rate"], new_freq=self.args.TTS_part_sample_rate + ) def init_multispeaker(self, config: Coqpit): """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer @@ -818,7 +820,6 @@ class Vits(BaseTTS): y: torch.tensor, y_lengths: torch.tensor, waveform: torch.tensor, - waveform_spec: torch.tensor, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, ) -> Dict: """Forward pass of the model. @@ -887,19 +888,14 @@ class Vits(BaseTTS): # select a random feature segment for the waveform decoder z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) - wav_seg2 = segment( - waveform_spec, - slice_ids * self.config.audio.hop_length, - self.spec_segment_size * self.config.audio.hop_length, - pad_short=True, - ) if self.args.TTS_part_sample_rate: slice_ids = slice_ids * int(self.interpolate_factor) spec_segment_size = self.spec_segment_size * int(self.interpolate_factor) if self.args.interpolate_z: - z_slice = z_slice.unsqueeze(0) # pylint: disable=not-callable + z_slice = z_slice.unsqueeze(0) # pylint: disable=not-callable z_slice = torch.nn.functional.interpolate( - z_slice, scale_factor=[1, self.interpolate_factor], mode='nearest').squeeze(0) + z_slice, scale_factor=[1, self.interpolate_factor], mode="nearest" + ).squeeze(0) else: spec_segment_size = self.spec_segment_size @@ -912,11 +908,6 @@ class Vits(BaseTTS): pad_short=True, ) - # print(o.shape, wav_seg.shape, spec_segment_size, self.spec_segment_size) - # self.ap.save_wav(wav_seg[0].squeeze(0).detach().cpu().numpy(), "/raid/edresson/dev/wav_GT_44khz.wav", sr=self.ap.sample_rate) - # self.ap.save_wav(wav_seg2[0].squeeze(0).detach().cpu().numpy(), "/raid/edresson/dev/wav_GT_22khz.wav", sr=self.args.TTS_part_sample_rate) - # self.ap.save_wav(o[0].squeeze(0).detach().cpu().numpy(), "/raid/edresson/dev/wav_gen_44khz_test_model_output.wav", sr=self.ap.sample_rate) - if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None: # concate generated and GT waveforms wavs_batch = torch.cat((wav_seg, o), dim=0) @@ -1021,10 +1012,11 @@ class Vits(BaseTTS): z = self.flow(z_p, y_mask, g=g, reverse=True) if self.args.TTS_part_sample_rate and self.args.interpolate_z: - z = z.unsqueeze(0) # pylint: disable=not-callable - z = torch.nn.functional.interpolate( - z, scale_factor=[1, self.interpolate_factor], mode='nearest').squeeze(0) - y_mask = sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1) # [B, 1, T_dec_resampled] + 
z = z.unsqueeze(0) # pylint: disable=not-callable + z = torch.nn.functional.interpolate(z, scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0) + y_mask = ( + sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1) + ) # [B, 1, T_dec_resampled] o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) @@ -1101,7 +1093,6 @@ class Vits(BaseTTS): self._freeze_layers() - mel_lens = batch["mel_lens"] spec_lens = batch["spec_lens"] if optimizer_idx == 0: @@ -1121,7 +1112,6 @@ class Vits(BaseTTS): spec, spec_lens, waveform, - batch["waveform_spec"], aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, ) @@ -1146,7 +1136,7 @@ class Vits(BaseTTS): # compute melspec segment with autocast(enabled=False): - + if self.args.TTS_part_sample_rate: spec_segment_size = self.spec_segment_size * int(self.interpolate_factor) else: @@ -1380,7 +1370,6 @@ class Vits(BaseTTS): else: spec_mel = batch["spec"] - batch["mel"] = spec_to_mel( spec=spec_mel, n_fft=ac.fft_size, @@ -1390,15 +1379,13 @@ class Vits(BaseTTS): fmax=ac.mel_fmax, ) - batch["waveform_spec"] = wav - if not self.args.TTS_part_sample_rate: assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" # compute spectrogram frame lengths batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int() batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int() - + if not self.args.TTS_part_sample_rate: assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0 diff --git a/TTS/tts/models/vits_old_way.py b/TTS/tts/models/vits_old_way.py deleted file mode 100644 index 6d9654d1..00000000 --- a/TTS/tts/models/vits_old_way.py +++ /dev/null @@ -1,1590 +0,0 @@ -import math -import os -from dataclasses import dataclass, field, replace -from itertools import chain -from typing import Dict, List, Tuple, Union - -import torch -import torch.distributed as dist -import torchaudio -from coqpit import Coqpit -from librosa.filters import mel as librosa_mel_fn -from torch import nn -from torch.cuda.amp.autocast_mode import autocast -from torch.nn import functional as F -from torch.utils.data import DataLoader -from trainer.trainer_utils import get_optimizer, get_scheduler - -from TTS.tts.configs.shared_configs import CharactersConfig -from TTS.tts.datasets.dataset import TTSDataset, _parse_sample -from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor -from TTS.tts.layers.vits.discriminator import VitsDiscriminator -from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder -from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor -from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask -from TTS.tts.utils.languages import LanguageManager -from TTS.tts.utils.speakers import SpeakerManager -from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations -from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.tts.utils.visual import plot_alignment -from TTS.vocoder.models.hifigan_generator import HifiganGenerator -from TTS.vocoder.utils.generic_utils import plot_results - -############################## -# IO / Feature extraction -############################## - -# pylint: disable=global-statement -hann_window = {} 
-mel_basis = {} - - -def load_audio(file_path): - """Load the audio file normalized in [-1, 1] - - Return Shapes: - - x: :math:`[1, T]` - """ - x, sr = torchaudio.load(file_path) - assert (x > 1).sum() + (x < -1).sum() == 0 - return x, sr - - -def _amp_to_db(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def _db_to_amp(x, C=1): - return torch.exp(x) / C - - -def amp_to_db(magnitudes): - output = _amp_to_db(magnitudes) - return output - - -def db_to_amp(magnitudes): - output = _db_to_amp(magnitudes) - return output - - -def wav_to_spec(y, n_fft, hop_length, win_length, center=False): - """ - Args Shapes: - - y : :math:`[B, 1, T]` - - Return Shapes: - - spec : :math:`[B,C,T]` - """ - y = y.squeeze(1) - - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - wnsize_dtype_device = str(win_length) + "_" + dtype_device - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - return spec - - -def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): - """ - Args Shapes: - - spec : :math:`[B,C,T]` - - Return Shapes: - - mel : :math:`[B,C,T]` - """ - global mel_basis - dtype_device = str(spec.dtype) + "_" + str(spec.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) - mel = torch.matmul(mel_basis[fmax_dtype_device], spec) - mel = amp_to_db(mel) - return mel - - -def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): - """ - Args Shapes: - - y : :math:`[B, 1, T]` - - Return Shapes: - - spec : :math:`[B,C,T]` - """ - y = y.squeeze(1) - - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global mel_basis, hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - wnsize_dtype_device = str(win_length) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 
- spec = amp_to_db(spec) - return spec - - -############################## -# DATASET -############################## - - -class VitsDataset(TTSDataset): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.pad_id = self.tokenizer.characters.pad_id - - def __getitem__(self, idx): - item = self.samples[idx] - raw_text = item["text"] - - wav, _ = load_audio(item["audio_file"]) - wav_filename = os.path.basename(item["audio_file"]) - - token_ids = self.get_token_ids(idx, item["text"]) - - # after phonemization the text length may change - # this is a shameful 🤭 hack to prevent longer phonemes - # TODO: find a better fix - if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: - self.rescue_item_idx += 1 - return self.__getitem__(self.rescue_item_idx) - - return { - "raw_text": raw_text, - "token_ids": token_ids, - "token_len": len(token_ids), - "wav": wav, - "wav_file": wav_filename, - "speaker_name": item["speaker_name"], - "language_name": item["language"], - } - - @property - def lengths(self): - lens = [] - for item in self.samples: - _, wav_file, *_ = _parse_sample(item) - audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio - lens.append(audio_len) - return lens - - def collate_fn(self, batch): - """ - Return Shapes: - - tokens: :math:`[B, T]` - - token_lens :math:`[B]` - - token_rel_lens :math:`[B]` - - waveform: :math:`[B, 1, T]` - - waveform_lens: :math:`[B]` - - waveform_rel_lens: :math:`[B]` - - speaker_names: :math:`[B]` - - language_names: :math:`[B]` - - audiofile_paths: :math:`[B]` - - raw_texts: :math:`[B]` - """ - # convert list of dicts to dict of lists - B = len(batch) - batch = {k: [dic[k] for dic in batch] for k in batch[0]} - - _, ids_sorted_decreasing = torch.sort( - torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True - ) - - max_text_len = max([len(x) for x in batch["token_ids"]]) - token_lens = torch.LongTensor(batch["token_len"]) - token_rel_lens = token_lens / token_lens.max() - - wav_lens = [w.shape[1] for w in batch["wav"]] - wav_lens = torch.LongTensor(wav_lens) - wav_lens_max = torch.max(wav_lens) - wav_rel_lens = wav_lens / wav_lens_max - - token_padded = torch.LongTensor(B, max_text_len) - wav_padded = torch.FloatTensor(B, 1, wav_lens_max) - token_padded = token_padded.zero_() + self.pad_id - wav_padded = wav_padded.zero_() + self.pad_id - for i in range(len(ids_sorted_decreasing)): - token_ids = batch["token_ids"][i] - token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) - - wav = batch["wav"][i] - wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) - - return { - "tokens": token_padded, - "token_lens": token_lens, - "token_rel_lens": token_rel_lens, - "waveform": wav_padded, # (B x T) - "waveform_lens": wav_lens, # (B) - "waveform_rel_lens": wav_rel_lens, - "speaker_names": batch["speaker_name"], - "language_names": batch["language_name"], - "audio_files": batch["wav_file"], - "raw_text": batch["raw_text"], - } - - -############################## -# MODEL DEFINITION -############################## - - -@dataclass -class VitsArgs(Coqpit): - """VITS model arguments. - - Args: - - num_chars (int): - Number of characters in the vocabulary. Defaults to 100. - - out_channels (int): - Number of output channels of the decoder. Defaults to 513. - - spec_segment_size (int): - Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`. - - hidden_channels (int): - Number of hidden channels of the model. Defaults to 192. 
- - hidden_channels_ffn_text_encoder (int): - Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 256. - - num_heads_text_encoder (int): - Number of attention heads of the text encoder transformer. Defaults to 2. - - num_layers_text_encoder (int): - Number of transformer layers in the text encoder. Defaults to 6. - - kernel_size_text_encoder (int): - Kernel size of the text encoder transformer FFN layers. Defaults to 3. - - dropout_p_text_encoder (float): - Dropout rate of the text encoder. Defaults to 0.1. - - dropout_p_duration_predictor (float): - Dropout rate of the duration predictor. Defaults to 0.1. - - kernel_size_posterior_encoder (int): - Kernel size of the posterior encoder's WaveNet layers. Defaults to 5. - - dilatation_posterior_encoder (int): - Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1. - - num_layers_posterior_encoder (int): - Number of posterior encoder's WaveNet layers. Defaults to 16. - - kernel_size_flow (int): - Kernel size of the Residual Coupling layers of the flow network. Defaults to 5. - - dilatation_flow (int): - Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1. - - num_layers_flow (int): - Number of Residual Coupling WaveNet layers of the flow network. Defaults to 6. - - resblock_type_decoder (str): - Type of the residual block in the decoder network. Defaults to "1". - - resblock_kernel_sizes_decoder (List[int]): - Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`. - - resblock_dilation_sizes_decoder (List[List[int]]): - Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`. - - upsample_rates_decoder (List[int]): - Upsampling rates for each concecutive upsampling layer in the decoder network. The multiply of these - values must be equal to the kop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`. - - upsample_initial_channel_decoder (int): - Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512. - - upsample_kernel_sizes_decoder (List[int]): - Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`. - - use_sdp (bool): - Use Stochastic Duration Predictor. Defaults to True. - - noise_scale (float): - Noise scale used for the sample noise tensor in training. Defaults to 1.0. - - inference_noise_scale (float): - Noise scale used for the sample noise tensor in inference. Defaults to 0.667. - - length_scale (float): - Scale factor for the predicted duration values. Smaller values result faster speech. Defaults to 1. - - noise_scale_dp (float): - Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0. - - inference_noise_scale_dp (float): - Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8. - - max_inference_len (int): - Maximum inference length to limit the memory use. Defaults to None. - - init_discriminator (bool): - Initialize the disciminator network if set True. Set False for inference. Defaults to True. - - use_spectral_norm_disriminator (bool): - Use spectral normalization over weight norm in the discriminator. Defaults to False. - - use_speaker_embedding (bool): - Enable/Disable speaker embedding for multi-speaker models. Defaults to False. - - num_speakers (int): - Number of speakers for the speaker embedding layer. Defaults to 0. 
- - speakers_file (str): - Path to the speaker mapping file for the Speaker Manager. Defaults to None. - - speaker_embedding_channels (int): - Number of speaker embedding channels. Defaults to 256. - - use_d_vector_file (bool): - Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. - - d_vector_file (str): - Path to the file including pre-computed speaker embeddings. Defaults to None. - - d_vector_dim (int): - Number of d-vector channels. Defaults to 0. - - detach_dp_input (bool): - Detach duration predictor's input from the network for stopping the gradients. Defaults to True. - - use_language_embedding (bool): - Enable/Disable language embedding for multilingual models. Defaults to False. - - embedded_language_dim (int): - Number of language embedding channels. Defaults to 4. - - num_languages (int): - Number of languages for the language embedding layer. Defaults to 0. - - language_ids_file (str): - Path to the language mapping file for the Language Manager. Defaults to None. - - use_speaker_encoder_as_loss (bool): - Enable/Disable Speaker Consistency Loss (SCL). Defaults to False. - - speaker_encoder_config_path (str): - Path to the file speaker encoder config file, to use for SCL. Defaults to "". - - speaker_encoder_model_path (str): - Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "". - - condition_dp_on_speaker (bool): - Condition the duration predictor on the speaker embedding. Defaults to True. - - freeze_encoder (bool): - Freeze the encoder weigths during training. Defaults to False. - - freeze_DP (bool): - Freeze the duration predictor weigths during training. Defaults to False. - - freeze_PE (bool): - Freeze the posterior encoder weigths during training. Defaults to False. - - freeze_flow_encoder (bool): - Freeze the flow encoder weigths during training. Defaults to False. - - freeze_waveform_decoder (bool): - Freeze the waveform decoder weigths during training. Defaults to False. 
- """ - - num_chars: int = 100 - out_channels: int = 513 - spec_segment_size: int = 32 - hidden_channels: int = 192 - hidden_channels_ffn_text_encoder: int = 768 - num_heads_text_encoder: int = 2 - num_layers_text_encoder: int = 6 - kernel_size_text_encoder: int = 3 - dropout_p_text_encoder: float = 0.1 - dropout_p_duration_predictor: float = 0.5 - kernel_size_posterior_encoder: int = 5 - dilation_rate_posterior_encoder: int = 1 - num_layers_posterior_encoder: int = 16 - kernel_size_flow: int = 5 - dilation_rate_flow: int = 1 - num_layers_flow: int = 4 - resblock_type_decoder: str = "1" - resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) - resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) - upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) - upsample_initial_channel_decoder: int = 512 - upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) - use_sdp: bool = True - noise_scale: float = 1.0 - inference_noise_scale: float = 0.667 - length_scale: float = 1 - noise_scale_dp: float = 1.0 - inference_noise_scale_dp: float = 1.0 - max_inference_len: int = None - init_discriminator: bool = True - use_spectral_norm_disriminator: bool = False - use_speaker_embedding: bool = False - num_speakers: int = 0 - speakers_file: str = None - d_vector_file: str = None - speaker_embedding_channels: int = 256 - use_d_vector_file: bool = False - d_vector_dim: int = 0 - detach_dp_input: bool = True - use_language_embedding: bool = False - embedded_language_dim: int = 4 - num_languages: int = 0 - language_ids_file: str = None - use_speaker_encoder_as_loss: bool = False - speaker_encoder_config_path: str = "" - speaker_encoder_model_path: str = "" - condition_dp_on_speaker: bool = True - freeze_encoder: bool = False - freeze_DP: bool = False - freeze_PE: bool = False - freeze_flow_decoder: bool = False - freeze_waveform_decoder: bool = False - TTS_part_sample_rate: int = None - interpolate_z: bool = True - detach_z_vocoder: bool = False - - -class Vits(BaseTTS): - """VITS TTS model - - Paper:: - https://arxiv.org/pdf/2106.06103.pdf - - Paper Abstract:: - Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel - sampling have been proposed, but their sample quality does not match that of two-stage TTS systems. - In this work, we present a parallel endto-end TTS method that generates more natural sounding audio than - current two-stage models. Our method adopts variational inference augmented with normalizing flows and - an adversarial training process, which improves the expressive power of generative modeling. We also propose a - stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the - uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the - natural one-to-many relationship in which a text input can be spoken in multiple ways - with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS) - on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly - available TTS systems and achieves a MOS comparable to ground truth. - - Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. 
- - Examples: - >>> from TTS.tts.configs.vits_config import VitsConfig - >>> from TTS.tts.models.vits import Vits - >>> config = VitsConfig() - >>> model = Vits(config) - """ - - def __init__( - self, - config: Coqpit, - ap: "AudioProcessor" = None, - tokenizer: "TTSTokenizer" = None, - speaker_manager: SpeakerManager = None, - language_manager: LanguageManager = None, - ): - - super().__init__(config, ap, tokenizer, speaker_manager, language_manager) - - self.init_multispeaker(config) - self.init_multilingual(config) - - self.length_scale = self.args.length_scale - self.noise_scale = self.args.noise_scale - self.inference_noise_scale = self.args.inference_noise_scale - self.inference_noise_scale_dp = self.args.inference_noise_scale_dp - self.noise_scale_dp = self.args.noise_scale_dp - self.max_inference_len = self.args.max_inference_len - self.spec_segment_size = self.args.spec_segment_size - - self.text_encoder = TextEncoder( - self.args.num_chars, - self.args.hidden_channels, - self.args.hidden_channels, - self.args.hidden_channels_ffn_text_encoder, - self.args.num_heads_text_encoder, - self.args.num_layers_text_encoder, - self.args.kernel_size_text_encoder, - self.args.dropout_p_text_encoder, - language_emb_dim=self.embedded_language_dim, - ) - - self.posterior_encoder = PosteriorEncoder( - self.args.out_channels, - self.args.hidden_channels, - self.args.hidden_channels, - kernel_size=self.args.kernel_size_posterior_encoder, - dilation_rate=self.args.dilation_rate_posterior_encoder, - num_layers=self.args.num_layers_posterior_encoder, - cond_channels=self.embedded_speaker_dim, - ) - - self.flow = ResidualCouplingBlocks( - self.args.hidden_channels, - self.args.hidden_channels, - kernel_size=self.args.kernel_size_flow, - dilation_rate=self.args.dilation_rate_flow, - num_layers=self.args.num_layers_flow, - cond_channels=self.embedded_speaker_dim, - ) - - if self.args.use_sdp: - self.duration_predictor = StochasticDurationPredictor( - self.args.hidden_channels, - 192, - 3, - self.args.dropout_p_duration_predictor, - 4, - cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, - language_emb_dim=self.embedded_language_dim, - ) - else: - self.duration_predictor = DurationPredictor( - self.args.hidden_channels, - 256, - 3, - self.args.dropout_p_duration_predictor, - cond_channels=self.embedded_speaker_dim, - language_emb_dim=self.embedded_language_dim, - ) - - self.waveform_decoder = HifiganGenerator( - self.args.hidden_channels, - 1, - self.args.resblock_type_decoder, - self.args.resblock_dilation_sizes_decoder, - self.args.resblock_kernel_sizes_decoder, - self.args.upsample_kernel_sizes_decoder, - self.args.upsample_initial_channel_decoder, - self.args.upsample_rates_decoder, - inference_padding=0, - cond_channels=self.embedded_speaker_dim, - conv_pre_weight_norm=False, - conv_post_weight_norm=False, - conv_post_bias=False, - ) - - if self.args.init_discriminator: - self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator) - - if self.args.TTS_part_sample_rate: - self.interpolate_factor = self.config.audio["sample_rate"] / self.args.TTS_part_sample_rate - self.audio_resampler = torchaudio.transforms.Resample(orig_freq=self.config.audio["sample_rate"], new_freq=self.args.TTS_part_sample_rate) - - def init_multispeaker(self, config: Coqpit): - """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer - or with external `d_vectors` computed from a speaker encoder model. 
- - You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. - - Args: - config (Coqpit): Model configuration. - data (List, optional): Dataset items to infer number of speakers. Defaults to None. - """ - self.embedded_speaker_dim = 0 - self.num_speakers = self.args.num_speakers - self.audio_transform = None - - if self.speaker_manager: - self.num_speakers = self.speaker_manager.num_speakers - - if self.args.use_speaker_embedding: - self._init_speaker_embedding() - - if self.args.use_d_vector_file: - self._init_d_vector() - - # TODO: make this a function - if self.args.use_speaker_encoder_as_loss: - if self.speaker_manager.speaker_encoder is None and ( - not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path - ): - raise RuntimeError( - " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" - ) - - self.speaker_manager.speaker_encoder.eval() - print(" > External Speaker Encoder Loaded !!") - - if ( - hasattr(self.speaker_manager.speaker_encoder, "audio_config") - and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"] - ): - self.audio_transform = torchaudio.transforms.Resample( - orig_freq=self.audio_config["sample_rate"], - new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], - ) - # pylint: disable=W0101,W0105 - self.audio_transform = torchaudio.transforms.Resample( - orig_freq=self.config.audio.sample_rate, - new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], - ) - - def _init_speaker_embedding(self): - # pylint: disable=attribute-defined-outside-init - if self.num_speakers > 0: - print(" > initialization of speaker-embedding layers.") - self.embedded_speaker_dim = self.args.speaker_embedding_channels - self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) - - def _init_d_vector(self): - # pylint: disable=attribute-defined-outside-init - if hasattr(self, "emb_g"): - raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") - self.embedded_speaker_dim = self.args.d_vector_dim - - def init_multilingual(self, config: Coqpit): - """Initialize multilingual modules of a model. - - Args: - config (Coqpit): Model configuration. 
- """ - if self.args.language_ids_file is not None: - self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) - - if self.args.use_language_embedding and self.language_manager: - print(" > initialization of language-embedding layers.") - self.num_languages = self.language_manager.num_languages - self.embedded_language_dim = self.args.embedded_language_dim - self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) - torch.nn.init.xavier_uniform_(self.emb_l.weight) - else: - self.embedded_language_dim = 0 - - def get_aux_input(self, aux_input: Dict): - sid, g, lid = self._set_cond_input(aux_input) - return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} - - def _freeze_layers(self): - if self.args.freeze_encoder: - for param in self.text_encoder.parameters(): - param.requires_grad = False - - if hasattr(self, "emb_l"): - for param in self.emb_l.parameters(): - param.requires_grad = False - - if self.args.freeze_PE: - for param in self.posterior_encoder.parameters(): - param.requires_grad = False - - if self.args.freeze_DP: - for param in self.duration_predictor.parameters(): - param.requires_grad = False - - if self.args.freeze_flow_decoder: - for param in self.flow.parameters(): - param.requires_grad = False - - if self.args.freeze_waveform_decoder: - for param in self.waveform_decoder.parameters(): - param.requires_grad = False - - @staticmethod - def _set_cond_input(aux_input: Dict): - """Set the speaker conditioning input based on the multi-speaker mode.""" - sid, g, lid = None, None, None - if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: - sid = aux_input["speaker_ids"] - if sid.ndim == 0: - sid = sid.unsqueeze_(0) - if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: - g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1) - if g.ndim == 2: - g = g.unsqueeze_(0) - - if "language_ids" in aux_input and aux_input["language_ids"] is not None: - lid = aux_input["language_ids"] - if lid.ndim == 0: - lid = lid.unsqueeze_(0) - - return sid, g, lid - - def _set_speaker_input(self, aux_input: Dict): - d_vectors = aux_input.get("d_vectors", None) - speaker_ids = aux_input.get("speaker_ids", None) - - if d_vectors is not None and speaker_ids is not None: - raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") - - if speaker_ids is not None and not hasattr(self, "emb_g"): - raise ValueError("[!] 
Cannot use speaker-ids without enabling speaker embedding.") - - g = speaker_ids if speaker_ids is not None else d_vectors - return g - - def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): - # find the alignment path - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - with torch.no_grad(): - o_scale = torch.exp(-2 * logs_p) - logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) - logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) - logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] - logp = logp2 + logp3 + logp1 + logp4 - attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] - - # duration predictor - attn_durations = attn.sum(3) - if self.args.use_sdp: - loss_duration = self.duration_predictor( - x.detach() if self.args.detach_dp_input else x, - x_mask, - attn_durations, - g=g.detach() if self.args.detach_dp_input and g is not None else g, - lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, - ) - loss_duration = loss_duration / torch.sum(x_mask) - else: - attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask - log_durations = self.duration_predictor( - x.detach() if self.args.detach_dp_input else x, - x_mask, - g=g.detach() if self.args.detach_dp_input and g is not None else g, - lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, - ) - loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) - outputs["loss_duration"] = loss_duration - return outputs, attn - - def forward( # pylint: disable=dangerous-default-value - self, - x: torch.tensor, - x_lengths: torch.tensor, - y: torch.tensor, - y_lengths: torch.tensor, - waveform: torch.tensor, - aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, - ) -> Dict: - """Forward pass of the model. - - Args: - x (torch.tensor): Batch of input character sequence IDs. - x_lengths (torch.tensor): Batch of input character sequence lengths. - y (torch.tensor): Batch of input spectrograms. - y_lengths (torch.tensor): Batch of input spectrogram lengths. - waveform (torch.tensor): Batch of ground truth waveforms per sample. - aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training. - Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}. - - Returns: - Dict: model outputs keyed by the output name. 
- - Shapes: - - x: :math:`[B, T_seq]` - - x_lengths: :math:`[B]` - - y: :math:`[B, C, T_spec]` - - y_lengths: :math:`[B]` - - waveform: :math:`[B, 1, T_wav]` - - d_vectors: :math:`[B, C, 1]` - - speaker_ids: :math:`[B]` - - language_ids: :math:`[B]` - - Return Shapes: - - model_outputs: :math:`[B, 1, T_wav]` - - alignments: :math:`[B, T_seq, T_dec]` - - z: :math:`[B, C, T_dec]` - - z_p: :math:`[B, C, T_dec]` - - m_p: :math:`[B, C, T_dec]` - - logs_p: :math:`[B, C, T_dec]` - - m_q: :math:`[B, C, T_dec]` - - logs_q: :math:`[B, C, T_dec]` - - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]` - - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` - - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` - """ - outputs = {} - sid, g, lid = self._set_cond_input(aux_input) - # speaker embedding - if self.args.use_speaker_embedding and sid is not None: - g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] - - # language embedding - lang_emb = None - if self.args.use_language_embedding and lid is not None: - lang_emb = self.emb_l(lid).unsqueeze(-1) - - x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) - - # posterior encoder - z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) - - # flow layers - z_p = self.flow(z, y_mask, g=g) - - # duration predictor - outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) - - # expand prior - m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) - logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) - - if self.args.TTS_part_sample_rate and self.args.interpolate_z: - z = z.unsqueeze(0) # pylint: disable=not-callable - z = torch.nn.functional.interpolate( - z, scale_factor=[1, self.interpolate_factor], mode='nearest').squeeze(0) - y_lengths = y_lengths * self.interpolate_factor - - # select a random feature segment for the waveform decoder - z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) - - o = self.waveform_decoder(z_slice.detach() if self.args.detach_z_vocoder else z_slice, g=g) - - if self.args.TTS_part_sample_rate and not self.args.interpolate_z: - slice_ids = slice_ids * int(self.interpolate_factor) - spec_segment_size = self.args.spec_segment_size * int(self.interpolate_factor) - else: - spec_segment_size = self.args.spec_segment_size - - wav_seg = segment( - waveform, - slice_ids * self.config.audio.hop_length, - spec_segment_size * self.config.audio.hop_length, - pad_short=True, - ) - - if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None: - # concate generated and GT waveforms - wavs_batch = torch.cat((wav_seg, o), dim=0) - - # resample audio to speaker encoder sample_rate - # pylint: disable=W0105 - if self.audio_transform is not None: - wavs_batch = self.audio_transform(wavs_batch) - - pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True) - - # split generated and GT speaker embeddings - gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0) - else: - gt_spk_emb, syn_spk_emb = None, None - - outputs.update( - { - "model_outputs": o, - "alignments": attn.squeeze(1), - "m_p": m_p, - "logs_p": logs_p, - "z": z, - "z_p": z_p, - "m_q": m_q, - "logs_q": logs_q, - "waveform_seg": wav_seg, - "gt_spk_emb": gt_spk_emb, - "syn_spk_emb": syn_spk_emb, - "slice_ids": slice_ids, - } - ) - return outputs - - @staticmethod - def _set_x_lengths(x, aux_input): - if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: - return 
aux_input["x_lengths"] - return torch.tensor(x.shape[1:2]).to(x.device) - - def inference( - self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None} - ): # pylint: disable=dangerous-default-value - """ - Note: - To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1. - - Shapes: - - x: :math:`[B, T_seq]` - - x_lengths: :math:`[B]` - - d_vectors: :math:`[B, C]` - - speaker_ids: :math:`[B]` - - Return Shapes: - - model_outputs: :math:`[B, 1, T_wav]` - - alignments: :math:`[B, T_seq, T_dec]` - - z: :math:`[B, C, T_dec]` - - z_p: :math:`[B, C, T_dec]` - - m_p: :math:`[B, C, T_dec]` - - logs_p: :math:`[B, C, T_dec]` - """ - sid, g, lid = self._set_cond_input(aux_input) - x_lengths = self._set_x_lengths(x, aux_input) - - # speaker embedding - if self.args.use_speaker_embedding and sid is not None: - g = self.emb_g(sid).unsqueeze(-1) - - # language embedding - lang_emb = None - if self.args.use_language_embedding and lid is not None: - lang_emb = self.emb_l(lid).unsqueeze(-1) - - x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) - - if self.args.use_sdp: - logw = self.duration_predictor( - x, - x_mask, - g=g if self.args.condition_dp_on_speaker else None, - reverse=True, - noise_scale=self.inference_noise_scale_dp, - lang_emb=lang_emb, - ) - else: - logw = self.duration_predictor( - x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb - ) - - w = torch.exp(logw) * x_mask * self.length_scale - w_ceil = torch.ceil(w) - y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() - y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec] - - attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1] - attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) - - m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) - logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2) - - z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale - z = self.flow(z_p, y_mask, g=g, reverse=True) - - if self.args.TTS_part_sample_rate and self.args.interpolate_z: - z = z.unsqueeze(0) # pylint: disable=not-callable - z = torch.nn.functional.interpolate( - z, scale_factor=[1, self.interpolate_factor], mode='nearest').squeeze(0) - y_mask = sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1) # [B, 1, T_dec_resampled] - - o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) - - outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p} - return outputs - - @torch.no_grad() - def inference_voice_conversion( - self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None - ): - """Inference for voice conversion - - Args: - reference_wav (Tensor): Reference wavform. Tensor of shape [B, T] - speaker_id (Tensor): speaker_id of the target speaker. Tensor of shape [B] - d_vector (Tensor): d_vector embedding of target speaker. Tensor of shape `[B, C]` - reference_speaker_id (Tensor): speaker_id of the reference_wav speaker. Tensor of shape [B] - reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. 
Tensor of shape `[B, C]` - """ - # compute spectrograms - y = wav_to_spec( - reference_wav, - self.config.audio.fft_size, - self.config.audio.hop_length, - self.config.audio.win_length, - center=False, - ).transpose(1, 2) - y_lengths = torch.tensor([y.size(-1)]).to(y.device) - speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector - speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector - # print(y.shape, y_lengths.shape) - wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt) - return wav - - def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt): - """Forward pass for voice conversion - - TODO: create an end-point for voice conversion - - Args: - y (Tensor): Reference spectrograms. Tensor of shape [B, T, C] - y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B] - speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,] - speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,] - """ - assert self.num_speakers > 0, "num_speakers have to be larger than 0." - # speaker embedding - if self.args.use_speaker_embedding and not self.args.use_d_vector_file: - g_src = self.emb_g(speaker_cond_src).unsqueeze(-1) - g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1) - elif not self.args.use_speaker_embedding and self.args.use_d_vector_file: - g_src = F.normalize(speaker_cond_src).unsqueeze(-1) - g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1) - else: - raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.") - - z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src) - z_p = self.flow(z, y_mask, g=g_src) - z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) - o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) - return o_hat, y_mask, (z, z_p, z_hat) - - def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: - """Perform a single training step. Run the model forward pass and compute losses. - - Args: - batch (Dict): Input tensors. - criterion (nn.Module): Loss layer designed for the model. - optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. - - Returns: - Tuple[Dict, Dict]: Model ouputs and computed losses. 
- """ - - self._freeze_layers() - - mel_lens = batch["mel_lens"] - - if optimizer_idx == 0: - tokens = batch["tokens"] - token_lenghts = batch["token_lens"] - spec = batch["spec"] - spec_lens = batch["spec_lens"] - - d_vectors = batch["d_vectors"] - speaker_ids = batch["speaker_ids"] - language_ids = batch["language_ids"] - waveform = batch["waveform"] - - # generator pass - outputs = self.forward( - tokens, - token_lenghts, - spec, - spec_lens, - waveform, - aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, - ) - - # cache tensors for the generator pass - self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init - - # compute scores and features - scores_disc_fake, _, scores_disc_real, _ = self.disc( - outputs["model_outputs"].detach(), outputs["waveform_seg"] - ) - - # compute loss - with autocast(enabled=False): # use float32 for the criterion - loss_dict = criterion[optimizer_idx]( - scores_disc_real, - scores_disc_fake, - ) - return outputs, loss_dict - - if optimizer_idx == 1: - mel = batch["mel"] - - # compute melspec segment - with autocast(enabled=False): - - if self.args.TTS_part_sample_rate and not self.args.interpolate_z: - slice_ids = self.model_outputs_cache["slice_ids"] * int(self.interpolate_factor) - spec_segment_size = self.args.spec_segment_size * int(self.interpolate_factor) - else: - spec_segment_size = self.args.spec_segment_size - slice_ids = self.model_outputs_cache["slice_ids"] - - mel_slice = segment( - mel.float(), slice_ids, spec_segment_size, pad_short=True - ) - mel_slice_hat = wav_to_mel( - y=self.model_outputs_cache["model_outputs"].float(), - n_fft=self.config.audio.fft_size, - sample_rate=self.config.audio.sample_rate, - num_mels=self.config.audio.num_mels, - hop_length=self.config.audio.hop_length, - win_length=self.config.audio.win_length, - fmin=self.config.audio.mel_fmin, - fmax=self.config.audio.mel_fmax, - center=False, - ) - - # compute discriminator scores and features - scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc( - self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] - ) - - # compute losses - with autocast(enabled=False): # use float32 for the criterion - loss_dict = criterion[optimizer_idx]( - mel_slice_hat=mel_slice.float(), - mel_slice=mel_slice_hat.float(), - z_p=self.model_outputs_cache["z_p"].float(), - logs_q=self.model_outputs_cache["logs_q"].float(), - m_p=self.model_outputs_cache["m_p"].float(), - logs_p=self.model_outputs_cache["logs_p"].float(), - z_len=mel_lens, - scores_disc_fake=scores_disc_fake, - feats_disc_fake=feats_disc_fake, - feats_disc_real=feats_disc_real, - loss_duration=self.model_outputs_cache["loss_duration"], - use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, - gt_spk_emb=self.model_outputs_cache["gt_spk_emb"], - syn_spk_emb=self.model_outputs_cache["syn_spk_emb"], - ) - - return self.model_outputs_cache, loss_dict - - raise ValueError(" [!] 
Unexpected `optimizer_idx`.") - - def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use - y_hat = outputs[1]["model_outputs"] - y = outputs[1]["waveform_seg"] - figures = plot_results(y_hat, y, ap, name_prefix) - sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() - audios = {f"{name_prefix}/audio": sample_voice} - - alignments = outputs[1]["alignments"] - align_img = alignments[0].data.cpu().numpy().T - - figures.update( - { - "alignment": plot_alignment(align_img, output_fig=False), - } - ) - return figures, audios - - def train_log( - self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int - ): # pylint: disable=no-self-use - """Create visualizations and waveform examples. - - For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to - be projected onto Tensorboard. - - Args: - ap (AudioProcessor): audio processor used at training. - batch (Dict): Model inputs used at the previous training step. - outputs (Dict): Model outputs generated at the previoud training step. - - Returns: - Tuple[Dict, np.ndarray]: training plots and output waveform. - """ - figures, audios = self._log(self.ap, batch, outputs, "train") - logger.train_figures(steps, figures) - logger.train_audios(steps, audios, self.ap.sample_rate) - - @torch.no_grad() - def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): - return self.train_step(batch, criterion, optimizer_idx) - - def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - figures, audios = self._log(self.ap, batch, outputs, "eval") - logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, self.ap.sample_rate) - - def get_aux_input_from_test_sentences(self, sentence_info): - if hasattr(self.config, "model_args"): - config = self.config.model_args - else: - config = self.config - - # extract speaker and language info - text, speaker_name, style_wav, language_name = None, None, None, None - - if isinstance(sentence_info, list): - if len(sentence_info) == 1: - text = sentence_info[0] - elif len(sentence_info) == 2: - text, speaker_name = sentence_info - elif len(sentence_info) == 3: - text, speaker_name, style_wav = sentence_info - elif len(sentence_info) == 4: - text, speaker_name, style_wav, language_name = sentence_info - else: - text = sentence_info - - # get speaker id/d_vector - speaker_id, d_vector, language_id = None, None, None - if hasattr(self, "speaker_manager"): - if config.use_d_vector_file: - if speaker_name is None: - d_vector = self.speaker_manager.get_random_d_vector() - else: - d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=None, randomize=False) - elif config.use_speaker_embedding: - if speaker_name is None: - speaker_id = self.speaker_manager.get_random_speaker_id() - else: - speaker_id = self.speaker_manager.speaker_ids[speaker_name] - - # get language id - if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: - language_id = self.language_manager.language_id_mapping[language_name] - - return { - "text": text, - "speaker_id": speaker_id, - "style_wav": style_wav, - "d_vector": d_vector, - "language_id": language_id, - "language_name": language_name, - } - - @torch.no_grad() - def test_run(self, assets) -> Tuple[Dict, Dict]: - """Generic test run for `tts` models used by `Trainer`. - - You can override this for a different behaviour. 
- - Returns: - Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. - """ - print(" | > Synthesizing test sentences.") - test_audios = {} - test_figures = {} - test_sentences = self.config.test_sentences - for idx, s_info in enumerate(test_sentences): - aux_inputs = self.get_aux_input_from_test_sentences(s_info) - wav, alignment, _, _ = synthesis( - self, - aux_inputs["text"], - self.config, - "cuda" in str(next(self.parameters()).device), - speaker_id=aux_inputs["speaker_id"], - d_vector=aux_inputs["d_vector"], - style_wav=aux_inputs["style_wav"], - language_id=aux_inputs["language_id"], - use_griffin_lim=True, - do_trim_silence=False, - ).values() - test_audios["{}-audio".format(idx)] = wav - test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) - return {"figures": test_figures, "audios": test_audios} - - def test_log( - self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument - ) -> None: - logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) - logger.test_figures(steps, outputs["figures"]) - - def format_batch(self, batch: Dict) -> Dict: - """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" - speaker_ids = None - language_ids = None - d_vectors = None - - # get numerical speaker ids from speaker names - if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding: - speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]] - - if speaker_ids is not None: - speaker_ids = torch.LongTensor(speaker_ids) - batch["speaker_ids"] = speaker_ids - - # get d_vectors from audio file names - if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file: - d_vector_mapping = self.speaker_manager.d_vectors - d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]] - d_vectors = torch.FloatTensor(d_vectors) - - # get language ids from language names - if ( - self.language_manager is not None - and self.language_manager.language_id_mapping - and self.args.use_language_embedding - ): - language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]] - - if language_ids is not None: - language_ids = torch.LongTensor(language_ids) - - batch["language_ids"] = language_ids - batch["d_vectors"] = d_vectors - batch["speaker_ids"] = speaker_ids - return batch - - def format_batch_on_device(self, batch): - """Compute spectrograms on the device.""" - ac = self.config.audio - - if self.args.TTS_part_sample_rate: - wav = self.audio_resampler(batch["waveform"]) - else: - wav = batch["waveform"] - - # compute spectrograms - batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) - batch["mel"] = spec_to_mel( - spec=batch["spec"], - n_fft=ac.fft_size, - num_mels=ac.num_mels, - sample_rate=ac.sample_rate, - fmin=ac.mel_fmin, - fmax=ac.mel_fmax, - ) - assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" - - # compute spectrogram frame lengths - batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int() - batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int() - assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0 - - # zero the padding frames - batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1) - batch["mel"] = batch["mel"] * 
sequence_mask(batch["mel_lens"]).unsqueeze(1) - return batch - - def get_data_loader( - self, - config: Coqpit, - assets: Dict, - is_eval: bool, - samples: Union[List[Dict], List[List]], - verbose: bool, - num_gpus: int, - rank: int = None, - ) -> "DataLoader": - if is_eval and not config.run_eval: - loader = None - else: - # init dataloader - dataset = VitsDataset( - samples=samples, - # batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, - min_text_len=config.min_text_len, - max_text_len=config.max_text_len, - min_audio_len=config.min_audio_len, - max_audio_len=config.max_audio_len, - phoneme_cache_path=config.phoneme_cache_path, - precompute_num_workers=config.precompute_num_workers, - verbose=verbose, - tokenizer=self.tokenizer, - start_by_longest=config.start_by_longest, - ) - - # wait all the DDP process to be ready - if num_gpus > 1: - dist.barrier() - - # sort input sequences from short to long - dataset.preprocess_samples() - - # get samplers - sampler = self.get_sampler(config, dataset, num_gpus) - - loader = DataLoader( - dataset, - batch_size=config.eval_batch_size if is_eval else config.batch_size, - shuffle=False, # shuffle is done in the dataset. - drop_last=False, # setting this False might cause issues in AMP training. - sampler=sampler, - collate_fn=dataset.collate_fn, - num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, - pin_memory=False, - ) - return loader - - def get_optimizer(self) -> List: - """Initiate and return the GAN optimizers based on the config parameters. - It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. - Returns: - List: optimizers. - """ - # select generator parameters - optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) - - gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) - optimizer1 = get_optimizer( - self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters - ) - return [optimizer0, optimizer1] - - def get_lr(self) -> List: - """Set the initial learning rates for each optimizer. - - Returns: - List: learning rates for each optimizer. - """ - return [self.config.lr_disc, self.config.lr_gen] - - def get_scheduler(self, optimizer) -> List: - """Set the schedulers for each optimizer. - - Args: - optimizer (List[`torch.optim.Optimizer`]): List of optimizers. - - Returns: - List: Schedulers, one for each optimizer. - """ - scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) - scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) - return [scheduler_D, scheduler_G] - - def get_criterion(self): - """Get criterions for each optimizer. 
The index in the output list matches the optimizer idx used in - `train_step()`""" - from TTS.tts.layers.losses import ( # pylint: disable=import-outside-toplevel - VitsDiscriminatorLoss, - VitsGeneratorLoss, - ) - - return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)] - - def load_checkpoint( - self, - config, - checkpoint_path, - eval=False, - strict=True, - ): # pylint: disable=unused-argument, redefined-builtin - """Load the model checkpoint and setup for training or inference""" - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) - # compat band-aid for the pre-trained models to not use the encoder baked into the model - # TODO: consider baking the speaker encoder into the model and call it from there. - # as it is probably easier for model distribution. - state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} - # handle fine-tuning from a checkpoint with additional speakers - if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: - num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] - print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") - emb_g = state["model"]["emb_g.weight"] - new_row = torch.randn(num_new_speakers, emb_g.shape[1]) - emb_g = torch.cat([emb_g, new_row], axis=0) - state["model"]["emb_g.weight"] = emb_g - # load the model weights - self.load_state_dict(state["model"], strict=strict) - - if eval: - self.eval() - assert not self.training - - @staticmethod - def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): - """Initiate model from config - - Args: - config (VitsConfig): Model config. - samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. - Defaults to None. - """ - from TTS.utils.audio import AudioProcessor - - upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() - if not config.model_args.TTS_part_sample_rate: - assert ( - upsample_rate == config.audio.hop_length - ), f" [!] 
Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" - - ap = AudioProcessor.init_from_config(config, verbose=verbose) - tokenizer, new_config = TTSTokenizer.init_from_config(config) - speaker_manager = SpeakerManager.init_from_config(config, samples) - language_manager = LanguageManager.init_from_config(config) - - if config.model_args.speaker_encoder_model_path: - speaker_manager.init_speaker_encoder( - config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path - ) - return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) - - -################################## -# VITS CHARACTERS -################################## - - -class VitsCharacters(BaseCharacters): - """Characters class for VITs model for compatibility with pre-trained models""" - - def __init__( - self, - graphemes: str = _characters, - punctuations: str = _punctuations, - pad: str = _pad, - ipa_characters: str = _phonemes, - ) -> None: - if ipa_characters is not None: - graphemes += ipa_characters - super().__init__(graphemes, punctuations, pad, None, None, "", is_unique=False, is_sorted=True) - - def _create_vocab(self): - self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] - self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} - # pylint: disable=unnecessary-comprehension - self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} - - @staticmethod - def init_from_config(config: Coqpit): - if config.characters is not None: - _pad = config.characters["pad"] - _punctuations = config.characters["punctuations"] - _letters = config.characters["characters"] - _letters_ipa = config.characters["phonemes"] - return ( - VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad), - config, - ) - characters = VitsCharacters() - new_config = replace(config, characters=characters.to_config()) - return characters, new_config - - def to_config(self) -> "CharactersConfig": - return CharactersConfig( - characters=self._characters, - punctuations=self._punctuations, - pad=self._pad, - eos=None, - bos=None, - blank=self._blank, - is_unique=False, - is_sorted=True, - ) diff --git a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py index f6583bf5..9f9eead9 100644 --- a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py @@ -76,12 +76,7 @@ config.model_args.TTS_part_sample_rate = 11025 config.model_args.interpolate_z = True config.model_args.detach_z_vocoder = True -config.model_args.upsample_rates_decoder = [ - 8, - 8, - 2, - 2 - ] +config.model_args.upsample_rates_decoder = [8, 8, 2, 2] config.save_json(config_path)
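Note on the `TTS_part_sample_rate` path touched in the `vits.py` hunks above: when the acoustic part of the model runs at a lower sample rate than the waveform decoder, the latent `z` is upsampled along the time axis with nearest-neighbour interpolation before being fed to the vocoder (see the `interpolate_z` branches in `forward()` and `inference()`). Below is a minimal, self-contained sketch of that single operation; the batch size, channel count, frame count, and factor are illustrative assumptions, not values from the patch.

import torch

# Illustrative values (assumptions, not taken from the patch):
# batch of 2, 192 latent channels, 50 decoder frames, 44100/22050 = 2x upsampling.
B, C, T_dec = 2, 192, 50
interpolate_factor = 2.0

z = torch.randn(B, C, T_dec)

# Same pattern as in the patch: add a leading dim so F.interpolate sees a 4D tensor,
# then scale only the last (time) axis; the channel axis is left untouched (factor 1).
z_up = torch.nn.functional.interpolate(
    z.unsqueeze(0), scale_factor=[1, interpolate_factor], mode="nearest"
).squeeze(0)

print(z_up.shape)  # torch.Size([2, 192, 100]) -> T_dec * interpolate_factor

Nearest-neighbour interpolation simply repeats each latent frame, which matches the integer ratio the patch works with elsewhere (the factor is `config.audio.sample_rate / TTS_part_sample_rate`, cast to `int` when adjusting `slice_ids` and `spec_segment_size`).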