From ea965a5683c56a39570b4cc91e86cd2bb9799308 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 7 Dec 2021 12:55:18 +0000
Subject: [PATCH] Update VITS for the new API

---
 TTS/tts/models/vits.py | 213 +++++++++++++++++++++--------------------
 1 file changed, 107 insertions(+), 106 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 1de26913..30dc7ec4 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1,7 +1,8 @@
 import math
-from dataclasses import dataclass, field
+import random
+from dataclasses import dataclass, field, replace
 from itertools import chain
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union
 
 import torch
 import torchaudio
@@ -10,6 +11,7 @@ from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 
+from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
@@ -19,6 +21,7 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
+from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@@ -283,91 +286,79 @@ class Vits(BaseTTS):
         self.END2END = True
         self.speaker_manager = speaker_manager
         self.language_manager = language_manager
-        if config.__class__.__name__ == "VitsConfig":
-            # loading from VitsConfig
-            self.num_chars = self.tokenizer.characters.num_chars
-            self.config = config
-            args = self.config.model_args
-        elif isinstance(config, VitsArgs):
-            # loading from VitsArgs
-            self.config = config
-            args = config
-        else:
-            raise ValueError("config must be either a VitsConfig or VitsArgs")
 
         self.args = args
 
         self.init_multispeaker(config)
         self.init_multilingual(config)
 
-        self.length_scale = args.length_scale
-        self.noise_scale = args.noise_scale
-        self.inference_noise_scale = args.inference_noise_scale
-        self.inference_noise_scale_dp = args.inference_noise_scale_dp
-        self.noise_scale_dp = args.noise_scale_dp
-        self.max_inference_len = args.max_inference_len
-        self.spec_segment_size = args.spec_segment_size
+        self.length_scale = self.args.length_scale
+        self.noise_scale = self.args.noise_scale
+        self.inference_noise_scale = self.args.inference_noise_scale
+        self.inference_noise_scale_dp = self.args.inference_noise_scale_dp
+        self.noise_scale_dp = self.args.noise_scale_dp
+        self.max_inference_len = self.args.max_inference_len
+        self.spec_segment_size = self.args.spec_segment_size
 
         self.text_encoder = TextEncoder(
-            args.num_chars,
-            args.hidden_channels,
-            args.hidden_channels,
-            args.hidden_channels_ffn_text_encoder,
-            args.num_heads_text_encoder,
-            args.num_layers_text_encoder,
-            args.kernel_size_text_encoder,
-            args.dropout_p_text_encoder,
-            language_emb_dim=self.embedded_language_dim,
+            self.args.num_chars,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            self.args.hidden_channels_ffn_text_encoder,
+            self.args.num_heads_text_encoder,
+            self.args.num_layers_text_encoder,
+            self.args.kernel_size_text_encoder,
+            self.args.dropout_p_text_encoder,
         )
 
         self.posterior_encoder = PosteriorEncoder(
-            args.out_channels,
-            args.hidden_channels,
-            args.hidden_channels,
-            kernel_size=args.kernel_size_posterior_encoder,
-            dilation_rate=args.dilation_rate_posterior_encoder,
-            num_layers=args.num_layers_posterior_encoder,
+            self.args.out_channels,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            kernel_size=self.args.kernel_size_posterior_encoder,
+            dilation_rate=self.args.dilation_rate_posterior_encoder,
+            num_layers=self.args.num_layers_posterior_encoder,
             cond_channels=self.embedded_speaker_dim,
         )
 
         self.flow = ResidualCouplingBlocks(
-            args.hidden_channels,
-            args.hidden_channels,
-            kernel_size=args.kernel_size_flow,
-            dilation_rate=args.dilation_rate_flow,
-            num_layers=args.num_layers_flow,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            kernel_size=self.args.kernel_size_flow,
+            dilation_rate=self.args.dilation_rate_flow,
+            num_layers=self.args.num_layers_flow,
             cond_channels=self.embedded_speaker_dim,
         )
 
-        if args.use_sdp:
+        if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
-                args.hidden_channels,
+                self.args.hidden_channels,
                 192,
                 3,
-                args.dropout_p_duration_predictor,
+                self.args.dropout_p_duration_predictor,
                 4,
                 cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
                 language_emb_dim=self.embedded_language_dim,
             )
         else:
             self.duration_predictor = DurationPredictor(
-                args.hidden_channels,
+                self.args.hidden_channels,
                 256,
                 3,
-                args.dropout_p_duration_predictor,
-                cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
+                self.args.dropout_p_duration_predictor,
+                cond_channels=self.embedded_speaker_dim,
                 language_emb_dim=self.embedded_language_dim,
             )
 
         self.waveform_decoder = HifiganGenerator(
-            args.hidden_channels,
+            self.args.hidden_channels,
             1,
-            args.resblock_type_decoder,
-            args.resblock_dilation_sizes_decoder,
-            args.resblock_kernel_sizes_decoder,
-            args.upsample_kernel_sizes_decoder,
-            args.upsample_initial_channel_decoder,
-            args.upsample_rates_decoder,
+            self.args.resblock_type_decoder,
+            self.args.resblock_dilation_sizes_decoder,
+            self.args.resblock_kernel_sizes_decoder,
+            self.args.upsample_kernel_sizes_decoder,
+            self.args.upsample_initial_channel_decoder,
+            self.args.upsample_rates_decoder,
             inference_padding=0,
             cond_channels=self.embedded_speaker_dim,
             conv_pre_weight_norm=False,
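Note: from this hunk on, every hyperparameter is read through `self.args` instead of
the local `args` that the deleted `VitsConfig`/`VitsArgs` dispatch used to set up;
under the new API the model is built from a `VitsConfig` whose `model_args` field
carries the `VitsArgs`. A minimal sketch of the resulting flow, assuming a default
`VitsConfig` (the `init_from_config` used here is added near the end of this patch):

    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits

    config = VitsConfig()  # wraps a VitsArgs instance as `model_args`
    model = Vits.init_from_config(config)
    # hyperparameters are now attributes of `model.args`
    print(model.args.hidden_channels, model.args.use_sdp)
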
@@ -375,8 +366,8 @@ class Vits(BaseTTS):
             conv_post_bias=False,
         )
 
-        if args.init_discriminator:
-            self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
+        if self.args.init_discriminator:
+            self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator)
 
     def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
@@ -883,19 +874,17 @@ class Vits(BaseTTS):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        ap = assets["audio_processor"]
-        self._log(ap, batch, outputs, "train")
+        self._log(self.ap, batch, outputs, "train")
 
     @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
         return self.train_step(batch, criterion, optimizer_idx)
 
     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        ap = assets["audio_processor"]
-        return self._log(ap, batch, outputs, "eval")
+        return self._log(self.ap, batch, outputs, "eval")
 
     @torch.no_grad()
-    def test_run(self, ap) -> Tuple[Dict, Dict]:
+    def test_run(self) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
 
         You can override this for a different behaviour.
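Note: the logging hooks above stop digging the audio processor out of the `assets`
dict and use `self.ap` instead, and `test_run` drops its `ap` argument to match;
the processor is attached to the model once, at construction time. Against the new
signatures, a subclass that customises evaluation logging would look roughly like
this (hypothetical `MyVits`, sketched for illustration only):

    from TTS.tts.models.vits import Vits

    class MyVits(Vits):
        def eval_log(self, batch, outputs, logger, assets, steps):
            # the audio processor now lives on the model, not in `assets`
            return self._log(self.ap, batch, outputs, "eval")
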
""" - ap = assets["audio_processor"] - self._log(ap, batch, outputs, "train") + self._log(self.ap, batch, outputs, "train") @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - return self._log(ap, batch, outputs, "eval") + return self._log(self.ap, batch, outputs, "eval") @torch.no_grad() - def test_run(self, ap) -> Tuple[Dict, Dict]: + def test_run(self) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. @@ -990,36 +979,6 @@ class Vits(BaseTTS): return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)] - @staticmethod - def make_symbols(config): - """Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the - whole training and inference steps.""" - _pad = config.characters["pad"] - _punctuations = config.characters["punctuations"] - _letters = config.characters["characters"] - _letters_ipa = config.characters["phonemes"] - symbols = [_pad] + list(_punctuations) + list(_letters) - if config.use_phonemes: - symbols += list(_letters_ipa) - return symbols - - @staticmethod - def get_characters(config: Coqpit): - if config.characters is not None: - symbols = Vits.make_symbols(config) - else: - from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel - parse_symbols, - phonemes, - symbols, - ) - - config.characters = parse_symbols() - if config.use_phonemes: - symbols = phonemes - num_chars = len(symbols) + getattr(config, "add_blank", False) - return symbols, config, num_chars - def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin @@ -1035,23 +994,65 @@ class Vits(BaseTTS): assert not self.training @staticmethod - def init_from_config(config: "Coqpit"): - """Initialize model from config.""" - - # init characters - if config.use_phonemes: - from TTS.tts.utils.text.characters import IPAPhonemes - - characters = IPAPhonemes().init_from_config(config) - else: - from TTS.tts.utils.text.characters import Graphemes - - characters = Graphemes().init_from_config(config) - config.num_chars = characters.num_chars + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
@@ -1035,23 +994,65 @@ class Vits(BaseTTS):
         assert not self.training
 
     @staticmethod
-    def init_from_config(config: "Coqpit"):
-        """Initialize model from config."""
-
-        # init characters
-        if config.use_phonemes:
-            from TTS.tts.utils.text.characters import IPAPhonemes
-
-            characters = IPAPhonemes().init_from_config(config)
-        else:
-            from TTS.tts.utils.text.characters import Graphemes
-
-            characters = Graphemes().init_from_config(config)
-        config.num_chars = characters.num_chars
+    def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
+        """Initialize the model from a config.
+        Args:
+            config (VitsConfig): Model config.
+            samples (Union[List[List], List[Dict]]): Training samples used to parse speaker ids for training.
+                Defaults to None.
+        """
 
         from TTS.utils.audio import AudioProcessor
 
         ap = AudioProcessor.init_from_config(config)
-        tokenizer = TTSTokenizer.init_from_config(config)
-        speaker_manager = SpeakerManager.init_from_config(config)
-        return Vits(config, ap, tokenizer, speaker_manager)
+        tokenizer, new_config = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config, samples)
+        return Vits(new_config, ap, tokenizer, speaker_manager)
+
+
+class VitsCharacters(BaseCharacters):
+    """Characters class for the VITS model, kept compatible with pre-trained models."""
+
+    def __init__(
+        self,
+        graphemes: str = _characters,
+        punctuations: str = _punctuations,
+        pad: str = _pad,
+        ipa_characters: str = _phonemes,
+    ) -> None:
+        if ipa_characters is not None:
+            graphemes += ipa_characters
+        super().__init__(graphemes, punctuations, pad, None, None, "<BLNK>", is_unique=False, is_sorted=True)
+
+    def _create_vocab(self):
+        self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
+        self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
+        # pylint: disable=unnecessary-comprehension
+        self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
+
+    @staticmethod
+    def init_from_config(config: Coqpit):
+        if config.characters is not None:
+            _pad = config.characters["pad"]
+            _punctuations = config.characters["punctuations"]
+            _letters = config.characters["characters"]
+            _letters_ipa = config.characters["phonemes"]
+            return (
+                VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad),
+                config,
+            )
+        characters = VitsCharacters()
+        new_config = replace(config, characters=characters.to_config())
+        return characters, new_config
+
+    def to_config(self) -> "CharactersConfig":
+        return CharactersConfig(
+            characters=self._characters,
+            punctuations=self._punctuations,
+            pad=self._pad,
+            eos=None,
+            bos=None,
+            blank=self._blank,
+            is_unique=False,
+            is_sorted=True,
+        )
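Note: `init_from_config` now threads an updated config through construction:
`TTSTokenizer.init_from_config` returns a `(tokenizer, new_config)` pair, and
`VitsCharacters.init_from_config` writes default characters back into the config via
`dataclasses.replace`, so the config that gets saved describes the exact vocabulary
in use. A rough round-trip sketch, assuming a default `VitsConfig` with no explicit
character set:

    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import VitsCharacters

    config = VitsConfig(characters=None)
    characters, new_config = VitsCharacters.init_from_config(config)
    assert new_config.characters is not None  # defaults were written back
    print(characters.id_to_char(0))  # the pad token sits at index 0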