mirror of https://github.com/coqui-ai/TTS.git

Update VITS for the new API

parent f802a931a3
commit ea965a5683
@@ -1,7 +1,8 @@
 import math
-from dataclasses import dataclass, field
+import random
+from dataclasses import dataclass, field, replace
 from itertools import chain
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union
 
 import torch
 import torchaudio
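The newly imported `replace` from `dataclasses` is what `VitsCharacters.init_from_config` (added in the final hunk of this diff) uses to return an updated config without mutating the caller's instance. A minimal stdlib-only sketch of that behavior, with a hypothetical `Cfg` dataclass unrelated to the real TTS config classes:

# dataclasses.replace returns a copy with selected fields overridden.
from dataclasses import dataclass, replace

@dataclass
class Cfg:  # hypothetical stand-in, not a TTS class
    num_chars: int = 0
    characters: object = None

cfg = Cfg()
new_cfg = replace(cfg, num_chars=120)
assert cfg.num_chars == 0 and new_cfg.num_chars == 120  # original untouched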
@@ -10,6 +11,7 @@ from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 
+from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
@@ -19,6 +21,7 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
+from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@@ -283,91 +286,79 @@ class Vits(BaseTTS):
         self.END2END = True
         self.speaker_manager = speaker_manager
         self.language_manager = language_manager
-        if config.__class__.__name__ == "VitsConfig":
-            # loading from VitsConfig
-            self.num_chars = self.tokenizer.characters.num_chars
-            self.config = config
-            args = self.config.model_args
-        elif isinstance(config, VitsArgs):
-            # loading from VitsArgs
-            self.config = config
-            args = config
-        else:
-            raise ValueError("config must be either a VitsConfig or VitsArgs")
 
         self.args = args
 
         self.init_multispeaker(config)
         self.init_multilingual(config)
 
-        self.length_scale = args.length_scale
-        self.noise_scale = args.noise_scale
-        self.inference_noise_scale = args.inference_noise_scale
-        self.inference_noise_scale_dp = args.inference_noise_scale_dp
-        self.noise_scale_dp = args.noise_scale_dp
-        self.max_inference_len = args.max_inference_len
-        self.spec_segment_size = args.spec_segment_size
+        self.length_scale = self.args.length_scale
+        self.noise_scale = self.args.noise_scale
+        self.inference_noise_scale = self.args.inference_noise_scale
+        self.inference_noise_scale_dp = self.args.inference_noise_scale_dp
+        self.noise_scale_dp = self.args.noise_scale_dp
+        self.max_inference_len = self.args.max_inference_len
+        self.spec_segment_size = self.args.spec_segment_size
 
         self.text_encoder = TextEncoder(
-            args.num_chars,
-            args.hidden_channels,
-            args.hidden_channels,
-            args.hidden_channels_ffn_text_encoder,
-            args.num_heads_text_encoder,
-            args.num_layers_text_encoder,
-            args.kernel_size_text_encoder,
-            args.dropout_p_text_encoder,
-            language_emb_dim=self.embedded_language_dim,
+            self.args.num_chars,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            self.args.hidden_channels_ffn_text_encoder,
+            self.args.num_heads_text_encoder,
+            self.args.num_layers_text_encoder,
+            self.args.kernel_size_text_encoder,
+            self.args.dropout_p_text_encoder,
         )
 
         self.posterior_encoder = PosteriorEncoder(
-            args.out_channels,
-            args.hidden_channels,
-            args.hidden_channels,
-            kernel_size=args.kernel_size_posterior_encoder,
-            dilation_rate=args.dilation_rate_posterior_encoder,
-            num_layers=args.num_layers_posterior_encoder,
+            self.args.out_channels,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            kernel_size=self.args.kernel_size_posterior_encoder,
+            dilation_rate=self.args.dilation_rate_posterior_encoder,
+            num_layers=self.args.num_layers_posterior_encoder,
             cond_channels=self.embedded_speaker_dim,
         )
 
         self.flow = ResidualCouplingBlocks(
-            args.hidden_channels,
-            args.hidden_channels,
-            kernel_size=args.kernel_size_flow,
-            dilation_rate=args.dilation_rate_flow,
-            num_layers=args.num_layers_flow,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            kernel_size=self.args.kernel_size_flow,
+            dilation_rate=self.args.dilation_rate_flow,
+            num_layers=self.args.num_layers_flow,
             cond_channels=self.embedded_speaker_dim,
         )
 
-        if args.use_sdp:
+        if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
-                args.hidden_channels,
+                self.args.hidden_channels,
                 192,
                 3,
-                args.dropout_p_duration_predictor,
+                self.args.dropout_p_duration_predictor,
                 4,
                 cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
                 language_emb_dim=self.embedded_language_dim,
             )
         else:
             self.duration_predictor = DurationPredictor(
-                args.hidden_channels,
+                self.args.hidden_channels,
                 256,
                 3,
-                args.dropout_p_duration_predictor,
-                cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
+                self.args.dropout_p_duration_predictor,
+                cond_channels=self.embedded_speaker_dim,
                 language_emb_dim=self.embedded_language_dim,
             )
 
         self.waveform_decoder = HifiganGenerator(
-            args.hidden_channels,
+            self.args.hidden_channels,
             1,
-            args.resblock_type_decoder,
-            args.resblock_dilation_sizes_decoder,
-            args.resblock_kernel_sizes_decoder,
-            args.upsample_kernel_sizes_decoder,
-            args.upsample_initial_channel_decoder,
-            args.upsample_rates_decoder,
+            self.args.resblock_type_decoder,
+            self.args.resblock_dilation_sizes_decoder,
+            self.args.resblock_kernel_sizes_decoder,
+            self.args.upsample_kernel_sizes_decoder,
+            self.args.upsample_initial_channel_decoder,
+            self.args.upsample_rates_decoder,
             inference_padding=0,
             cond_channels=self.embedded_speaker_dim,
             conv_pre_weight_norm=False,
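The hunk above drops the old `VitsConfig`-vs-`VitsArgs` dispatch from the constructor and reads every hyperparameter through `self.args`. A minimal sketch of the resulting access pattern, using hypothetical stand-in classes rather than the real `Vits`/`VitsArgs`:

from dataclasses import dataclass

@dataclass
class ArgsSketch:  # hypothetical stand-in for VitsArgs
    hidden_channels: int = 192
    use_sdp: bool = True

class ModelSketch:  # hypothetical stand-in for Vits
    def __init__(self, args: ArgsSketch):
        # Store the args once; submodules and later methods all read
        # self.args, so there is a single source of truth for hyperparameters.
        self.args = args
        self.dp_in_channels = self.args.hidden_channels

model = ModelSketch(ArgsSketch())
assert model.dp_in_channels == 192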
@@ -375,8 +366,8 @@ class Vits(BaseTTS):
             conv_post_bias=False,
         )
 
-        if args.init_discriminator:
-            self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
+        if self.args.init_discriminator:
+            self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator)
 
     def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
@@ -883,19 +874,17 @@ class Vits(BaseTTS):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        ap = assets["audio_processor"]
-        self._log(ap, batch, outputs, "train")
+        self._log(self.ap, batch, outputs, "train")
 
     @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
         return self.train_step(batch, criterion, optimizer_idx)
 
     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        ap = assets["audio_processor"]
-        return self._log(ap, batch, outputs, "eval")
+        return self._log(self.ap, batch, outputs, "eval")
 
     @torch.no_grad()
-    def test_run(self, ap) -> Tuple[Dict, Dict]:
+    def test_run(self) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
 
         You can override this for a different behaviour.
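In the logging hunk above, the model stops fetching the audio processor from the trainer's `assets` dict and uses its own `self.ap` (held by the model in the new API); `test_run` loses its `ap` argument for the same reason. A small sketch of the after-state callback shape, with hypothetical minimal classes standing in for the real `Vits` and `AudioProcessor`:

class AudioProcessorStub:  # hypothetical stand-in for TTS.utils.audio.AudioProcessor
    sample_rate = 22050

class ModelAfterSketch:  # hypothetical stand-in for Vits under the new API
    def __init__(self, ap):
        self.ap = ap  # the model now owns its audio processor

    def _log(self, ap, batch, outputs, name):
        # placeholder for the figure/audio logging done by the real Vits._log
        print(f"[{name}] sample_rate={ap.sample_rate}")

    def train_log(self, batch, outputs, logger, assets, steps):
        # `assets` is still accepted for API compatibility but no longer read
        self._log(self.ap, batch, outputs, "train")

ModelAfterSketch(AudioProcessorStub()).train_log({}, {}, None, {}, 0)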
@@ -990,36 +979,6 @@ class Vits(BaseTTS):
 
         return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
 
-    @staticmethod
-    def make_symbols(config):
-        """Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the
-        whole training and inference steps."""
-        _pad = config.characters["pad"]
-        _punctuations = config.characters["punctuations"]
-        _letters = config.characters["characters"]
-        _letters_ipa = config.characters["phonemes"]
-        symbols = [_pad] + list(_punctuations) + list(_letters)
-        if config.use_phonemes:
-            symbols += list(_letters_ipa)
-        return symbols
-
-    @staticmethod
-    def get_characters(config: Coqpit):
-        if config.characters is not None:
-            symbols = Vits.make_symbols(config)
-        else:
-            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
-                parse_symbols,
-                phonemes,
-                symbols,
-            )
-
-            config.characters = parse_symbols()
-        if config.use_phonemes:
-            symbols = phonemes
-        num_chars = len(symbols) + getattr(config, "add_blank", False)
-        return symbols, config, num_chars
-
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
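The removed `make_symbols`/`get_characters` helpers are superseded by the `VitsCharacters` class added in the next hunk. The token ordering is preserved, which is what keeps pre-trained checkpoints compatible; a toy comparison (the symbol sets here are made up, not the real defaults):

# Toy symbol sets, not the real TTS defaults.
_pad, _punct, _letters, _ipa = "_", "!,.", "abc", "ɐɑ"

# What the removed make_symbols(config) produced (with use_phonemes=True):
old_symbols = [_pad] + list(_punct) + list(_letters) + list(_ipa)

# What VitsCharacters._create_vocab builds after __init__ appends IPA to graphemes:
new_vocab = [_pad] + list(_punct) + list(_letters + _ipa) + ["<BLNK>"]

# Same ids for all real tokens; the new vocab only appends the blank token.
assert new_vocab[: len(old_symbols)] == old_symbols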
@@ -1035,23 +994,65 @@ class Vits(BaseTTS):
             assert not self.training
 
     @staticmethod
-    def init_from_config(config: "Coqpit"):
-        """Initialize model from config."""
-
-        # init characters
-        if config.use_phonemes:
-            from TTS.tts.utils.text.characters import IPAPhonemes
-
-            characters = IPAPhonemes().init_from_config(config)
-        else:
-            from TTS.tts.utils.text.characters import Graphemes
-
-            characters = Graphemes().init_from_config(config)
-        config.num_chars = characters.num_chars
-
+    def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
+        """Initiate model from config
+
+        Args:
+            config (VitsConfig): Model config.
+            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
+                Defaults to None.
+        """
         from TTS.utils.audio import AudioProcessor
 
         ap = AudioProcessor.init_from_config(config)
-        tokenizer = TTSTokenizer.init_from_config(config)
-        speaker_manager = SpeakerManager.init_from_config(config)
-        return Vits(config, ap, tokenizer, speaker_manager)
+        tokenizer, new_config = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config, samples)
+        return Vits(new_config, ap, tokenizer, speaker_manager)
+
+
+class VitsCharacters(BaseCharacters):
+    """Characters class for VITs model for compatibility with pre-trained models"""
+
+    def __init__(
+        self,
+        graphemes: str = _characters,
+        punctuations: str = _punctuations,
+        pad: str = _pad,
+        ipa_characters: str = _phonemes,
+    ) -> None:
+        if ipa_characters is not None:
+            graphemes += ipa_characters
+        super().__init__(graphemes, punctuations, pad, None, None, "<BLNK>", is_unique=False, is_sorted=True)
+
+    def _create_vocab(self):
+        self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
+        self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
+        # pylint: disable=unnecessary-comprehension
+        self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
+
+    @staticmethod
+    def init_from_config(config: Coqpit):
+        if config.characters is not None:
+            _pad = config.characters["pad"]
+            _punctuations = config.characters["punctuations"]
+            _letters = config.characters["characters"]
+            _letters_ipa = config.characters["phonemes"]
+            return (
+                VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad),
+                config,
+            )
+        characters = VitsCharacters()
+        new_config = replace(config, characters=characters.to_config())
+        return characters, new_config
+
+    def to_config(self) -> "CharactersConfig":
+        return CharactersConfig(
+            characters=self._characters,
+            punctuations=self._punctuations,
+            pad=self._pad,
+            eos=None,
+            bos=None,
+            blank=self._blank,
+            is_unique=False,
+            is_sorted=True,
+        )