Update VITS for the new API

Eren Gölge 2021-12-07 12:55:18 +00:00
parent f802a931a3
commit ea965a5683
1 changed file with 107 additions and 106 deletions


@@ -1,7 +1,8 @@
 import math
-from dataclasses import dataclass, field
+import random
+from dataclasses import dataclass, field, replace
 from itertools import chain
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union

 import torch
 import torchaudio
@@ -10,6 +11,7 @@ from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F

+from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
@@ -19,6 +21,7 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
+from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
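Note on the new imports: `random` and `Union` serve changes elsewhere in the file, while `replace` from `dataclasses` is used near the bottom of this diff to derive an updated config without mutating the original. A minimal sketch of that pattern, with an illustrative stand-in dataclass (not the real VitsConfig):

    from dataclasses import dataclass, replace

    # Illustrative stand-in; not the real VitsConfig.
    @dataclass
    class DummyConfig:
        characters: dict = None
        add_blank: bool = True

    config = DummyConfig()
    # `replace` returns a copy with the given fields overridden,
    # leaving the original config untouched.
    new_config = replace(config, characters={"pad": "_"})
    assert config.characters is None
    assert new_config.characters == {"pad": "_"}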
@@ -283,91 +286,79 @@ class Vits(BaseTTS):
         self.END2END = True
         self.speaker_manager = speaker_manager
         self.language_manager = language_manager
-        if config.__class__.__name__ == "VitsConfig":
-            # loading from VitsConfig
-            self.num_chars = self.tokenizer.characters.num_chars
-            self.config = config
-            args = self.config.model_args
-        elif isinstance(config, VitsArgs):
-            # loading from VitsArgs
-            self.config = config
-            args = config
-        else:
-            raise ValueError("config must be either a VitsConfig or VitsArgs")
         self.args = args

         self.init_multispeaker(config)
         self.init_multilingual(config)

-        self.length_scale = args.length_scale
-        self.noise_scale = args.noise_scale
-        self.inference_noise_scale = args.inference_noise_scale
-        self.inference_noise_scale_dp = args.inference_noise_scale_dp
-        self.noise_scale_dp = args.noise_scale_dp
-        self.max_inference_len = args.max_inference_len
-        self.spec_segment_size = args.spec_segment_size
+        self.length_scale = self.args.length_scale
+        self.noise_scale = self.args.noise_scale
+        self.inference_noise_scale = self.args.inference_noise_scale
+        self.inference_noise_scale_dp = self.args.inference_noise_scale_dp
+        self.noise_scale_dp = self.args.noise_scale_dp
+        self.max_inference_len = self.args.max_inference_len
+        self.spec_segment_size = self.args.spec_segment_size

         self.text_encoder = TextEncoder(
-            args.num_chars,
-            args.hidden_channels,
-            args.hidden_channels,
-            args.hidden_channels_ffn_text_encoder,
-            args.num_heads_text_encoder,
-            args.num_layers_text_encoder,
-            args.kernel_size_text_encoder,
-            args.dropout_p_text_encoder,
-            language_emb_dim=self.embedded_language_dim,
+            self.args.num_chars,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            self.args.hidden_channels_ffn_text_encoder,
+            self.args.num_heads_text_encoder,
+            self.args.num_layers_text_encoder,
+            self.args.kernel_size_text_encoder,
+            self.args.dropout_p_text_encoder,
         )

         self.posterior_encoder = PosteriorEncoder(
-            args.out_channels,
-            args.hidden_channels,
-            args.hidden_channels,
-            kernel_size=args.kernel_size_posterior_encoder,
-            dilation_rate=args.dilation_rate_posterior_encoder,
-            num_layers=args.num_layers_posterior_encoder,
+            self.args.out_channels,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            kernel_size=self.args.kernel_size_posterior_encoder,
+            dilation_rate=self.args.dilation_rate_posterior_encoder,
+            num_layers=self.args.num_layers_posterior_encoder,
             cond_channels=self.embedded_speaker_dim,
         )

         self.flow = ResidualCouplingBlocks(
-            args.hidden_channels,
-            args.hidden_channels,
-            kernel_size=args.kernel_size_flow,
-            dilation_rate=args.dilation_rate_flow,
-            num_layers=args.num_layers_flow,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            kernel_size=self.args.kernel_size_flow,
+            dilation_rate=self.args.dilation_rate_flow,
+            num_layers=self.args.num_layers_flow,
             cond_channels=self.embedded_speaker_dim,
         )

-        if args.use_sdp:
+        if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
-                args.hidden_channels,
+                self.args.hidden_channels,
                 192,
                 3,
-                args.dropout_p_duration_predictor,
+                self.args.dropout_p_duration_predictor,
                 4,
                 cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
                 language_emb_dim=self.embedded_language_dim,
             )
         else:
             self.duration_predictor = DurationPredictor(
-                args.hidden_channels,
+                self.args.hidden_channels,
                 256,
                 3,
-                args.dropout_p_duration_predictor,
-                cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
+                self.args.dropout_p_duration_predictor,
+                cond_channels=self.embedded_speaker_dim,
                 language_emb_dim=self.embedded_language_dim,
             )

         self.waveform_decoder = HifiganGenerator(
-            args.hidden_channels,
+            self.args.hidden_channels,
             1,
-            args.resblock_type_decoder,
-            args.resblock_dilation_sizes_decoder,
-            args.resblock_kernel_sizes_decoder,
-            args.upsample_kernel_sizes_decoder,
-            args.upsample_initial_channel_decoder,
-            args.upsample_rates_decoder,
+            self.args.resblock_type_decoder,
+            self.args.resblock_dilation_sizes_decoder,
+            self.args.resblock_kernel_sizes_decoder,
+            self.args.upsample_kernel_sizes_decoder,
+            self.args.upsample_initial_channel_decoder,
+            self.args.upsample_rates_decoder,
             inference_padding=0,
             cond_channels=self.embedded_speaker_dim,
             conv_pre_weight_norm=False,
@@ -375,8 +366,8 @@ class Vits(BaseTTS):
             conv_post_bias=False,
         )

-        if args.init_discriminator:
-            self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
+        if self.args.init_discriminator:
+            self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator)

     def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
@@ -883,19 +874,17 @@ class Vits(BaseTTS):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        ap = assets["audio_processor"]
-        self._log(ap, batch, outputs, "train")
+        self._log(self.ap, batch, outputs, "train")

     @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
         return self.train_step(batch, criterion, optimizer_idx)

     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        ap = assets["audio_processor"]
-        return self._log(ap, batch, outputs, "eval")
+        return self._log(self.ap, batch, outputs, "eval")

     @torch.no_grad()
-    def test_run(self, ap) -> Tuple[Dict, Dict]:
+    def test_run(self) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.

         You can override this for a different behaviour.
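The logging hooks and `test_run` lose their audio-processor plumbing: instead of fetching the processor from the trainer's `assets` dict or a dedicated `ap` argument, the model now uses the `self.ap` it received at construction. A schematic sketch of the new calling convention (not the real Trainer or BaseTTS):

    # Before: model.train_log(...) pulled ap out of assets["audio_processor"].
    # After: the model owns its processor, and hooks forward it themselves.
    class ModelSketch:
        def __init__(self, ap):
            self.ap = ap  # stored once at construction, as in the new API

        def _log(self, ap, batch, outputs, mode):
            print(f"logging {mode} figures with {ap}")

        def train_log(self, batch, outputs, logger, assets, steps):
            # `assets` is still accepted for interface compatibility but unused.
            self._log(self.ap, batch, outputs, "train")

    ModelSketch(ap="audio-processor").train_log({}, {}, None, {}, 0)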
@@ -990,36 +979,6 @@ class Vits(BaseTTS):
         return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]

-    @staticmethod
-    def make_symbols(config):
-        """Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the
-        whole training and inference steps."""
-        _pad = config.characters["pad"]
-        _punctuations = config.characters["punctuations"]
-        _letters = config.characters["characters"]
-        _letters_ipa = config.characters["phonemes"]
-        symbols = [_pad] + list(_punctuations) + list(_letters)
-        if config.use_phonemes:
-            symbols += list(_letters_ipa)
-        return symbols
-
-    @staticmethod
-    def get_characters(config: Coqpit):
-        if config.characters is not None:
-            symbols = Vits.make_symbols(config)
-        else:
-            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
-                parse_symbols,
-                phonemes,
-                symbols,
-            )
-
-            config.characters = parse_symbols()
-        if config.use_phonemes:
-            symbols = phonemes
-        num_chars = len(symbols) + getattr(config, "add_blank", False)
-        return symbols, config, num_chars
-
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
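The removed `make_symbols`/`get_characters` helpers assembled the symbol table by hand; that responsibility moves to the tokenizer and the `VitsCharacters` class added at the end of this diff. For reference, the deleted ordering reduces to the following (keys as in the old `config.characters` dict):

    # Reference re-implementation of the removed Vits.make_symbols.
    def make_symbols(characters: dict, use_phonemes: bool) -> list:
        symbols = [characters["pad"]]  # pad first, so it maps to id 0
        symbols += list(characters["punctuations"])
        symbols += list(characters["characters"])
        if use_phonemes:
            symbols += list(characters["phonemes"])
        return symbols

    chars = {"pad": "_", "punctuations": "!,. ", "characters": "abc", "phonemes": "ɑɔ"}
    print(make_symbols(chars, use_phonemes=True))
    # -> ['_', '!', ',', '.', ' ', 'a', 'b', 'c', 'ɑ', 'ɔ']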
@@ -1035,23 +994,65 @@ class Vits(BaseTTS):
         assert not self.training

     @staticmethod
-    def init_from_config(config: "Coqpit"):
-        """Initialize model from config."""
-        # init characters
-        if config.use_phonemes:
-            from TTS.tts.utils.text.characters import IPAPhonemes
-
-            characters = IPAPhonemes().init_from_config(config)
-        else:
-            from TTS.tts.utils.text.characters import Graphemes
-
-            characters = Graphemes().init_from_config(config)
-        config.num_chars = characters.num_chars
+    def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
+        """Initiate model from config
+
+        Args:
+            config (VitsConfig): Model config.
+            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
+                Defaults to None.
+        """
         from TTS.utils.audio import AudioProcessor

         ap = AudioProcessor.init_from_config(config)
-        tokenizer = TTSTokenizer.init_from_config(config)
-        speaker_manager = SpeakerManager.init_from_config(config)
-        return Vits(config, ap, tokenizer, speaker_manager)
+        tokenizer, new_config = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config, samples)
+        return Vits(new_config, ap, tokenizer, speaker_manager)
+
+
+class VitsCharacters(BaseCharacters):
+    """Characters class for VITs model for compatibility with pre-trained models"""
+
+    def __init__(
+        self,
+        graphemes: str = _characters,
+        punctuations: str = _punctuations,
+        pad: str = _pad,
+        ipa_characters: str = _phonemes,
+    ) -> None:
+        if ipa_characters is not None:
+            graphemes += ipa_characters
+        super().__init__(graphemes, punctuations, pad, None, None, "<BLNK>", is_unique=False, is_sorted=True)
+
+    def _create_vocab(self):
+        self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
+        self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
+        # pylint: disable=unnecessary-comprehension
+        self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
+
+    @staticmethod
+    def init_from_config(config: Coqpit):
+        if config.characters is not None:
+            _pad = config.characters["pad"]
+            _punctuations = config.characters["punctuations"]
+            _letters = config.characters["characters"]
+            _letters_ipa = config.characters["phonemes"]
+            return (
+                VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad),
+                config,
+            )
+        characters = VitsCharacters()
+        new_config = replace(config, characters=characters.to_config())
+        return characters, new_config
+
+    def to_config(self) -> "CharactersConfig":
+        return CharactersConfig(
+            characters=self._characters,
+            punctuations=self._punctuations,
+            pad=self._pad,
+            eos=None,
+            bos=None,
+            blank=self._blank,
+            is_unique=False,
+            is_sorted=True,
+        )