mirror of https://github.com/coqui-ai/TTS.git
Create base 🐸TTS model abstraction for tts models
parent a358f74a52
commit 7b8c15ac49
@@ -1,9 +1,9 @@
|
|||
from coqpit import Coqpit
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
@@ -11,8 +11,8 @@ from TTS.utils.audio import AudioProcessor
|
|||
# pylint: skip-file
|
||||
|
||||
|
||||
class TTSModel(nn.Module, ABC):
|
||||
"""Abstract TTS class. Every new `tts` model must inherit this.
|
||||
class BaseModel(nn.Module, ABC):
|
||||
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
|
||||
|
||||
Notes on input/output tensor shapes:
|
||||
Any input or output tensor of the model must be shaped as
|
||||
|
@@ -77,7 +77,6 @@ class TTSModel(nn.Module, ABC):
|
|||
...
|
||||
return outputs_dict, loss_dict
|
||||
|
||||
@abstractmethod
|
||||
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||
"""Create visualizations and waveform examples for training.
|
||||
|
||||
|
@@ -92,10 +91,7 @@ class TTSModel(nn.Module, ABC):
|
|||
Returns:
|
||||
Tuple[Dict, np.ndarray]: training plots and output waveform.
|
||||
"""
|
||||
figures_dict = {}
|
||||
output_wav = np.array()
|
||||
...
|
||||
return figures_dict, output_wav
|
||||
return None, None
|
||||
|
||||
@abstractmethod
|
||||
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
|
||||
|
@@ -114,13 +110,9 @@ class TTSModel(nn.Module, ABC):
|
|||
...
|
||||
return outputs_dict, loss_dict
|
||||
|
||||
@abstractmethod
|
||||
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||
"""The same as `train_log()`"""
|
||||
figures_dict = {}
|
||||
output_wav = np.array()
|
||||
...
|
||||
return figures_dict, output_wav
|
||||
return None, None
|
||||
|
||||
@abstractmethod
|
||||
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
|
||||
|
@@ -132,3 +124,24 @@ class TTSModel(nn.Module, ABC):
|
|||
eval (bool, optional): If True, initialize the model for inference; otherwise for training. Defaults to False.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
|
||||
"""Setup an return optimizer or optimizers."""
|
||||
pass
|
||||
|
||||
def get_lr(self) -> Union[float, List[float]]:
|
||||
"""Return learning rate(s).
|
||||
|
||||
Returns:
|
||||
Union[float, List[float]]: Model's initial learning rates.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_scheduler(self, optimizer: torch.optim.Optimizer):
|
||||
pass
|
||||
|
||||
def get_criterion(self):
|
||||
pass
|
||||
|
||||
def format_batch(self):
|
||||
pass
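For orientation, here is a minimal sketch of how a concrete model could fill in the optimizer hooks above. The subclass name, learning rate, and scheduler choice are illustrative only and are not part of this commit:

import torch
from TTS.model import BaseModel

class MyTTS(BaseModel):  # illustrative, partial subclass (abstract methods omitted)
    def get_optimizer(self):
        # single optimizer; returning a list is also allowed by the return type above
        return torch.optim.Adam(self.parameters(), lr=self.get_lr())

    def get_lr(self):
        return 1e-3  # assumed initial learning rate

    def get_scheduler(self, optimizer: torch.optim.Optimizer):
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)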
|
|
@@ -1,5 +1,9 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.tts.layers.align_tts.mdn import MDNBlock
|
||||
from TTS.tts.layers.feed_forward.decoder import Decoder
|
||||
|
@@ -7,36 +11,16 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
|
|||
from TTS.tts.layers.feed_forward.encoder import Encoder
|
||||
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
||||
from TTS.tts.models.abstract_tts import TTSModel
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class AlignTTS(TTSModel):
|
||||
"""AlignTTS with modified duration predictor.
|
||||
https://arxiv.org/pdf/2003.01950.pdf
|
||||
|
||||
Encoder -> DurationPredictor -> Decoder
|
||||
|
||||
AlignTTS's Abstract - Targeting at both high efficiency and performance, we propose AlignTTS to predict the
|
||||
mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a
|
||||
sequence of characters, and the duration of each character is determined by a duration predictor. Instead of
adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented
to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset
show that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean
opinion score (MOS), but also a high efficiency which is more than 50 times faster than real-time.
|
||||
|
||||
Note:
|
||||
Original model uses a separate character embedding layer for duration predictor. However, it causes the
|
||||
duration predictor to overfit and prevents learning higher level interactions among characters. Therefore,
|
||||
we predict durations based on encoder outputs which have higher level information about input characters. This
enables training without the phases used in the original paper.

Original model uses Transformers in encoder and decoder layers. However, here you can set the architecture
|
||||
differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters.
|
||||
|
||||
@dataclass
|
||||
class AlignTTSArgs(Coqpit):
|
||||
"""
|
||||
Args:
|
||||
num_chars (int):
|
||||
number of unique input characters
|
||||
|
@@ -64,43 +48,98 @@ class AlignTTS(TTSModel):
|
|||
number of channels in speaker embedding vectors. Defaults to 0.
|
||||
"""
|
||||
|
||||
num_chars: int = None
|
||||
out_channels: int = 80
|
||||
hidden_channels: int = 256
|
||||
hidden_channels_dp: int = 256
|
||||
encoder_type: str = "fftransformer"
|
||||
encoder_params: dict = field(
|
||||
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
|
||||
)
|
||||
decoder_type: str = "fftransformer"
|
||||
decoder_params: dict = field(
|
||||
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
|
||||
)
|
||||
length_scale: float = 1.0
|
||||
num_speakers: int = 0
|
||||
use_speaker_embedding: bool = False
|
||||
use_d_vector_file: bool = False
|
||||
d_vector_dim: int = 0
|
||||
|
||||
|
||||
class AlignTTS(BaseTTS):
|
||||
"""AlignTTS with modified duration predictor.
|
||||
https://arxiv.org/pdf/2003.01950.pdf
|
||||
|
||||
Encoder -> DurationPredictor -> Decoder
|
||||
|
||||
Check ```AlignTTSArgs``` for the class arguments.
|
||||
|
||||
Examples:
|
||||
>>> from TTS.tts.configs import AlignTTSConfig
|
||||
>>> config = AlignTTSConfig()
|
||||
>>> config.model_args.num_chars = 50
|
||||
>>> model = AlignTTS(config)
|
||||
|
||||
Paper Abstract:
|
||||
Targeting at both high efficiency and performance, we propose AlignTTS to predict the
|
||||
mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a
|
||||
sequence of characters, and the duration of each character is determined by a duration predictor. Instead of
adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented
to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset
show that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean
opinion score (MOS), but also a high efficiency which is more than 50 times faster than real-time.
|
||||
|
||||
Note:
|
||||
Original model uses a separate character embedding layer for duration predictor. However, it causes the
|
||||
duration predictor to overfit and prevents learning higher level interactions among characters. Therefore,
|
||||
we predict durations based on encoder outputs which have higher level information about input characters. This
enables training without the phases used in the original paper.

Original model uses Transformers in encoder and decoder layers. However, here you can set the architecture
|
||||
differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters.
|
||||
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
out_channels,
|
||||
hidden_channels=256,
|
||||
hidden_channels_dp=256,
|
||||
encoder_type="fftransformer",
|
||||
encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
|
||||
decoder_type="fftransformer",
|
||||
decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
|
||||
length_scale=1,
|
||||
num_speakers=0,
|
||||
external_c=False,
|
||||
c_in_channels=0,
|
||||
):
|
||||
def __init__(self, config: Coqpit):
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.phase = -1
|
||||
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
|
||||
self.emb = nn.Embedding(num_chars, hidden_channels)
|
||||
self.pos_encoder = PositionalEncoding(hidden_channels)
|
||||
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
|
||||
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
|
||||
self.duration_predictor = DurationPredictor(hidden_channels_dp)
|
||||
self.length_scale = (
|
||||
float(config.model_args.length_scale)
|
||||
if isinstance(config.model_args.length_scale, int)
|
||||
else config.model_args.length_scale
|
||||
)
|
||||
self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)
|
||||
|
||||
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||
self.mdn_block = MDNBlock(hidden_channels, 2 * out_channels)
|
||||
self.embedded_speaker_dim = 0
|
||||
self.init_multispeaker(config)
|
||||
|
||||
if num_speakers > 1 and not external_c:
|
||||
# speaker embedding layer
|
||||
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels)
|
||||
self.encoder = Encoder(
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.encoder_type,
|
||||
config.model_args.encoder_params,
|
||||
self.embedded_speaker_dim,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
config.model_args.out_channels,
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.decoder_type,
|
||||
config.model_args.decoder_params,
|
||||
)
|
||||
self.duration_predictor = DurationPredictor(config.model_args.hidden_channels_dp)
|
||||
|
||||
if c_in_channels > 0 and c_in_channels != hidden_channels:
|
||||
self.proj_g = nn.Conv1d(c_in_channels, hidden_channels, 1)
|
||||
self.mod_layer = nn.Conv1d(config.model_args.hidden_channels, config.model_args.hidden_channels, 1)
|
||||
|
||||
self.mdn_block = MDNBlock(config.model_args.hidden_channels, 2 * config.model_args.out_channels)
|
||||
|
||||
if self.embedded_speaker_dim > 0 and self.embedded_speaker_dim != config.model_args.hidden_channels:
|
||||
self.proj_g = nn.Conv1d(self.embedded_speaker_dim, config.model_args.hidden_channels, 1)
|
||||
|
||||
@staticmethod
|
||||
def compute_log_probs(mu, log_sigma, y):
|
||||
|
@@ -164,11 +203,12 @@ class AlignTTS(TTSModel):
|
|||
# project g to decoder dim.
|
||||
if hasattr(self, "proj_g"):
|
||||
g = self.proj_g(g)
|
||||
|
||||
return x + g
|
||||
|
||||
def _forward_encoder(self, x, x_lengths, g=None):
|
||||
if hasattr(self, "emb_g"):
|
||||
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]
|
||||
g = nn.functional.normalize(self.speaker_embedding(g)) # [B, C, 1]
|
||||
|
||||
if g is not None:
|
||||
g = g.unsqueeze(-1)
|
||||
|
@@ -315,7 +355,9 @@ class AlignTTS(TTSModel):
|
|||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
|
||||
def train_log(
|
||||
self, ap: AudioProcessor, batch: dict, outputs: dict
|
||||
) -> Tuple[Dict, Dict]: # pylint: disable=no-self-use
|
||||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
@@ -332,7 +374,7 @@ class AlignTTS(TTSModel):
|
|||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||
return figures, train_audio
|
||||
return figures, {"audio": train_audio}
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
@@ -349,6 +391,11 @@ class AlignTTS(TTSModel):
|
|||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
def get_criterion(self):
|
||||
from TTS.tts.layers.losses import AlignTTSLoss # pylint: disable=import-outside-toplevel
|
||||
|
||||
return AlignTTSLoss(self.config)
|
||||
|
||||
@staticmethod
|
||||
def _set_phase(config, global_step):
|
||||
"""Decide AlignTTS training phase"""
|
||||
|
|
|
@@ -0,0 +1,286 @@
|
|||
import copy
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from coqpit import MISSING, Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.losses import TacotronLoss
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.text import make_symbols
|
||||
from TTS.utils.generic_utils import format_aux_input
|
||||
from TTS.utils.training import gradual_training_scheduler
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseTacotronArgs(Coqpit):
|
||||
"""TODO: update Tacotron configs using it"""
|
||||
|
||||
num_chars: int = MISSING
|
||||
num_speakers: int = MISSING
|
||||
r: int = MISSING
|
||||
out_channels: int = 80
|
||||
decoder_output_dim: int = 80
|
||||
attn_type: str = "original"
|
||||
attn_win: bool = False
|
||||
attn_norm: str = "softmax"
|
||||
prenet_type: str = "original"
|
||||
prenet_dropout: bool = True
|
||||
prenet_dropout_at_inference: bool = False
|
||||
forward_attn: bool = False
|
||||
trans_agent: bool = False
|
||||
forward_attn_mask: bool = False
|
||||
location_attn: bool = True
|
||||
attn_K: int = 5
|
||||
separate_stopnet: bool = True
|
||||
bidirectional_decoder: bool = False
|
||||
double_decoder_consistency: bool = False
|
||||
ddc_r: int = None
|
||||
encoder_in_features: int = 512
|
||||
decoder_in_features: int = 512
|
||||
d_vector_dim: int = None
|
||||
use_gst: bool = False
|
||||
gst: bool = None
|
||||
gradual_training: bool = None
|
||||
|
||||
|
||||
class BaseTacotron(BaseTTS):
|
||||
def __init__(self, config: Coqpit):
|
||||
"""Abstract Tacotron class"""
|
||||
super().__init__()
|
||||
|
||||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
# layers
|
||||
self.embedding = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
self.postnet = None
|
||||
|
||||
# init tensors
|
||||
self.embedded_speakers = None
|
||||
self.embedded_speakers_projected = None
|
||||
|
||||
# global style token
|
||||
if self.gst and self.use_gst:
|
||||
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
|
||||
self.gst_layer = None
|
||||
|
||||
# additional layers
|
||||
self.decoder_backward = None
|
||||
self.coarse_decoder = None
|
||||
|
||||
# init multi-speaker layers
|
||||
self.init_multispeaker(config)
|
||||
|
||||
@staticmethod
|
||||
def _format_aux_input(aux_input: Dict) -> Dict:
|
||||
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
|
||||
|
||||
#############################
|
||||
# INIT FUNCTIONS
|
||||
#############################
|
||||
|
||||
def _init_states(self):
|
||||
self.embedded_speakers = None
|
||||
self.embedded_speakers_projected = None
|
||||
|
||||
def _init_backward_decoder(self):
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
|
||||
def _init_coarse_decoder(self):
|
||||
self.coarse_decoder = copy.deepcopy(self.decoder)
|
||||
self.coarse_decoder.r_init = self.ddc_r
|
||||
self.coarse_decoder.set_r(self.ddc_r)
|
||||
|
||||
#############################
|
||||
# CORE FUNCTIONS
|
||||
#############################
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inference(self):
|
||||
pass
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if "r" in state:
|
||||
self.decoder.set_r(state["r"])
|
||||
else:
|
||||
self.decoder.set_r(state["config"]["r"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
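As a usage note, restoring a trained model for inference with the method above looks roughly as follows; the config object and checkpoint path are placeholders:

# Illustrative only: `config` and the checkpoint path are placeholders.
model = Tacotron(config)
model.load_checkpoint(config, "/path/to/best_model.pth.tar", eval=True)
assert not model.training  # eval=True switches the model to inference mode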
|
||||
|
||||
def get_criterion(self) -> nn.Module:
|
||||
return TacotronLoss(self.config)
|
||||
|
||||
@staticmethod
|
||||
def get_characters(config: Coqpit) -> str:
|
||||
# TODO: implement CharacterProcessor
|
||||
if config.characters is not None:
|
||||
symbols, phonemes = make_symbols(**config.characters)
|
||||
else:
|
||||
from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel
|
||||
parse_symbols,
|
||||
phonemes,
|
||||
symbols,
|
||||
)
|
||||
|
||||
config.characters = parse_symbols()
|
||||
model_characters = phonemes if config.use_phonemes else symbols
|
||||
return model_characters, config
|
||||
|
||||
@staticmethod
|
||||
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
|
||||
return get_speaker_manager(config, restore_path, data, out_path)
|
||||
|
||||
def get_aux_input(self, **kwargs) -> Dict:
|
||||
"""Compute Tacotron's auxiliary inputs based on model config.
|
||||
- speaker d_vector
|
||||
- style wav for GST
|
||||
- speaker ID for speaker embedding
|
||||
"""
|
||||
# setup speaker_id
|
||||
if self.config.use_speaker_embedding:
|
||||
speaker_id = kwargs.get("speaker_id", 0)
|
||||
else:
|
||||
speaker_id = None
|
||||
# setup d_vector
|
||||
d_vector = (
|
||||
self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
|
||||
if self.config.use_d_vector_file and self.config.use_speaker_embedding
|
||||
else None
|
||||
)
|
||||
# setup style_mel
|
||||
if "style_wav" in kwargs:
|
||||
style_wav = kwargs["style_wav"]
|
||||
elif self.config.has("gst_style_input"):
|
||||
style_wav = self.config.gst_style_input
|
||||
else:
|
||||
style_wav = None
|
||||
if style_wav is None and "use_gst" in self.config and self.config.use_gst:
|
||||
# initialize GST with a zero dict.
style_wav = {}
print("WARNING: You didn't provide a GST style wav; using a zero tensor instead!")
|
||||
for i in range(self.config.gst["gst_num_style_tokens"]):
|
||||
style_wav[str(i)] = 0
|
||||
aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
|
||||
return aux_inputs
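For reference, the dict form of `style_wav` built above maps GST token indices (as strings) to amplifier weights. A hand-written version with non-zero weights would look like this (values are illustrative):

# Illustrative: bias GST token 0 up and token 2 slightly down instead of the all-zero default.
aux_inputs = {"speaker_id": None, "d_vector": None, "style_wav": {"0": 0.15, "2": -0.1}}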
|
||||
|
||||
#############################
|
||||
# COMMON COMPUTE FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_masks(self, text_lengths, mel_lengths):
|
||||
"""Compute masks against sequence paddings."""
|
||||
# B x T_in_max (boolean)
|
||||
input_mask = sequence_mask(text_lengths)
|
||||
output_mask = None
|
||||
if mel_lengths is not None:
|
||||
max_len = mel_lengths.max()
|
||||
r = self.decoder.r
|
||||
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
|
||||
output_mask = sequence_mask(mel_lengths, max_len=max_len)
|
||||
return input_mask, output_mask
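A quick toy check of the pad-to-a-multiple-of-`r` arithmetic used above (numbers are made up):

# Illustrative: with r=5, a max mel length of 37 is padded up to 40 frames.
max_len, r = 37, 5
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
assert max_len == 40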
|
||||
|
||||
def _backward_pass(self, mel_specs, encoder_outputs, mask):
|
||||
"""Run backwards decoder"""
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
|
||||
)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
|
||||
"""Double Decoder Consistency"""
|
||||
T = mel_specs.shape[1]
|
||||
if T % self.coarse_decoder.r > 0:
|
||||
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
|
||||
mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
|
||||
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
|
||||
encoder_outputs.detach(), mel_specs, input_mask
|
||||
)
|
||||
# scale_factor = self.decoder.r_init / self.decoder.r
|
||||
alignments_backward = torch.nn.functional.interpolate(
|
||||
alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest"
|
||||
).transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
|
||||
return decoder_outputs_backward, alignments_backward
|
||||
|
||||
#############################
|
||||
# EMBEDDING FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_speaker_embedding(self, speaker_ids):
|
||||
"""Compute speaker embedding vectors"""
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
self.embedded_speakers = self.speaker_embedding(speaker_ids).unsqueeze(1)
|
||||
if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
|
||||
self.embedded_speakers_projected = self.speaker_project_mel(self.embedded_speakers).squeeze(1)
|
||||
|
||||
def compute_gst(self, inputs, style_input, speaker_embedding=None):
|
||||
"""Compute global style token"""
|
||||
if isinstance(style_input, dict):
|
||||
query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
|
||||
if speaker_embedding is not None:
|
||||
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
|
||||
|
||||
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
|
||||
for k_token, v_amplifier in style_input.items():
|
||||
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
|
||||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
||||
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
|
||||
elif style_input is None:
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
|
||||
else:
|
||||
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
|
||||
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
||||
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(outputs, embedded_speakers):
|
||||
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
|
||||
outputs = outputs + embedded_speakers_
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, embedded_speakers):
|
||||
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
|
||||
outputs = torch.cat([outputs, embedded_speakers_], dim=-1)
|
||||
return outputs
|
||||
|
||||
#############################
|
||||
# CALLBACKS
|
||||
#############################
|
||||
|
||||
def on_epoch_start(self, trainer):
|
||||
"""Callback for setting values wrt gradual training schedule.
|
||||
|
||||
Args:
|
||||
trainer (TrainerTTS): TTS trainer object that is used to train this model.
|
||||
"""
|
||||
if self.gradual_training:
|
||||
r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
|
||||
trainer.config.r = r
|
||||
self.decoder.set_r(r)
|
||||
if trainer.config.bidirectional_decoder:
|
||||
trainer.model.decoder_backward.set_r(r)
|
||||
trainer.train_loader = trainer.setup_train_dataloader(self.ap, self.model.decoder.r, verbose=True)
|
||||
trainer.eval_loader = trainer.setup_eval_dataloder(self.ap, self.model.decoder.r)
|
||||
print(f"\n > Number of output frames: {self.decoder.r}")
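For context, the schedule consumed by `gradual_training_scheduler` is a list of `[step, r, batch_size]` entries; the values below are illustrative, not defaults:

# Illustrative schedule on the training config: from step 0 use r=6 and batch size 64,
# from step 10k use r=4 and batch size 32, from step 50k use r=2.
config.gradual_training = [[0, 6, 64], [10000, 4, 32], [50000, 2, 32]]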
|
|
@@ -0,0 +1,233 @@
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.model import BaseModel
|
||||
from TTS.tts.datasets import TTSDataset
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text import make_symbols
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
|
||||
class BaseTTS(BaseModel):
|
||||
"""Abstract `tts` class. Every new `tts` model must inherit this.
|
||||
|
||||
It defines `tts` specific functions on top of `BaseModel`.
|
||||
|
||||
Notes on input/output tensor shapes:
|
||||
Any input or output tensor of the model must be shaped as
|
||||
|
||||
- 3D tensors `batch x time x channels`
|
||||
- 2D tensors `batch x channels`
|
||||
- 1D tensors `batch x 1`
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_characters(config: Coqpit) -> str:
|
||||
# TODO: implement CharacterProcessor
|
||||
if config.characters is not None:
|
||||
symbols, phonemes = make_symbols(**config.characters)
|
||||
else:
|
||||
from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
|
||||
|
||||
config.characters = parse_symbols()
|
||||
model_characters = phonemes if config.use_phonemes else symbols
|
||||
return model_characters, config
|
||||
|
||||
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
|
||||
return get_speaker_manager(config, restore_path, data, out_path)
|
||||
|
||||
def init_multispeaker(self, config: Coqpit, data: List = None):
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
|
||||
If you need a different behaviour, override this function for your model.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||
"""
|
||||
# init speaker manager
|
||||
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = (
|
||||
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
||||
)
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
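A hedged sketch of the call pattern a model constructor follows around `init_multispeaker()`; the subclass name, config layout, and projection layer mirror what `AlignTTS` does earlier in this commit and are illustrative:

from torch import nn
from TTS.tts.models.base_tts import BaseTTS

class MyTTS(BaseTTS):  # illustrative subclass
    def __init__(self, config):
        super().__init__()
        self.embedded_speaker_dim = 0
        self.init_multispeaker(config)  # sets speaker_manager, num_speakers and the embedding layer
        hidden = config.model_args.hidden_channels  # assumed config layout
        if self.embedded_speaker_dim > 0 and self.embedded_speaker_dim != hidden:
            # project speaker embeddings to the hidden size, as AlignTTS does
            self.proj_g = nn.Conv1d(self.embedded_speaker_dim, hidden, 1)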
|
||||
|
||||
def get_aux_input(self, **kwargs) -> Dict:
|
||||
"""Prepare and return `aux_input` used by `forward()`"""
|
||||
pass
|
||||
|
||||
def format_batch(self, batch: Dict) -> Dict:
|
||||
"""Generic batch formatting for `TTSDataset`.
|
||||
|
||||
You must override this if you use a custom dataset.
|
||||
|
||||
Args:
|
||||
batch (Dict): raw batch as produced by the `TTSDataset` collate function.
|
||||
|
||||
Returns:
|
||||
Dict: formatted batch with named model inputs and targets.
|
||||
"""
|
||||
# setup input batch
|
||||
text_input = batch[0]
|
||||
text_lengths = batch[1]
|
||||
speaker_names = batch[2]
|
||||
linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
|
||||
mel_input = batch[4]
|
||||
mel_lengths = batch[5]
|
||||
stop_targets = batch[6]
|
||||
item_idx = batch[7]
|
||||
d_vectors = batch[8]
|
||||
speaker_ids = batch[9]
|
||||
attn_mask = batch[10]
|
||||
max_text_length = torch.max(text_lengths.float())
|
||||
max_spec_length = torch.max(mel_lengths.float())
|
||||
|
||||
# compute durations from attention masks
|
||||
durations = None
|
||||
if attn_mask is not None:
|
||||
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
|
||||
for idx, am in enumerate(attn_mask):
|
||||
# compute raw durations
|
||||
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
|
||||
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
|
||||
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
|
||||
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
|
||||
dur[c_idxs] = counts
|
||||
# smooth the durations and set any 0 duration to 1
|
||||
# by cutting off from the largest duration indices.
|
||||
extra_frames = dur.sum() - mel_lengths[idx]
|
||||
largest_idxs = torch.argsort(-dur)[:extra_frames]
|
||||
dur[largest_idxs] -= 1
|
||||
assert (
|
||||
dur.sum() == mel_lengths[idx]
|
||||
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
|
||||
durations[idx, : text_lengths[idx]] = dur
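The duration trimming above can be checked on a toy example (numbers are made up):

import torch

dur = torch.tensor([3, 5, 2, 4])                   # raw per-character durations
mel_len = 12                                       # spectrogram length for this item
extra_frames = int(dur.sum() - mel_len)            # 14 - 12 = 2 extra frames
largest_idxs = torch.argsort(-dur)[:extra_frames]  # indices of the 2 largest durations
dur[largest_idxs] -= 1                             # -> tensor([3, 4, 2, 3])
assert dur.sum() == mel_len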
|
||||
|
||||
# set stop targets view, we predict a single stop token per iteration.
|
||||
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
||||
|
||||
return {
|
||||
"text_input": text_input,
|
||||
"text_lengths": text_lengths,
|
||||
"speaker_names": speaker_names,
|
||||
"mel_input": mel_input,
|
||||
"mel_lengths": mel_lengths,
|
||||
"linear_input": linear_input,
|
||||
"stop_targets": stop_targets,
|
||||
"attn_mask": attn_mask,
|
||||
"durations": durations,
|
||||
"speaker_ids": speaker_ids,
|
||||
"d_vectors": d_vectors,
|
||||
"max_text_length": float(max_text_length),
|
||||
"max_spec_length": float(max_spec_length),
|
||||
"item_idx": item_idx,
|
||||
}
|
||||
|
||||
def get_data_loader(
|
||||
self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
|
||||
) -> "DataLoader":
|
||||
if is_eval and not config.run_eval:
|
||||
loader = None
|
||||
else:
|
||||
# setup multi-speaker attributes
|
||||
if hasattr(self, "speaker_manager"):
|
||||
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
|
||||
d_vector_mapping = (
|
||||
self.speaker_manager.d_vectors
|
||||
if config.use_speaker_embedding and config.use_d_vector_file
|
||||
else None
|
||||
)
|
||||
else:
|
||||
speaker_id_mapping = None
|
||||
d_vector_mapping = None
|
||||
|
||||
# init dataloader
|
||||
dataset = TTSDataset(
|
||||
outputs_per_step=config.r if "r" in config else 1,
|
||||
text_cleaner=config.text_cleaner,
|
||||
compute_linear_spec=config.model.lower() == "tacotron",
|
||||
meta_data=data_items,
|
||||
ap=ap,
|
||||
tp=config.characters,
|
||||
add_blank=config["add_blank"],
|
||||
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
||||
min_seq_len=config.min_seq_len,
|
||||
max_seq_len=config.max_seq_len,
|
||||
phoneme_cache_path=config.phoneme_cache_path,
|
||||
use_phonemes=config.use_phonemes,
|
||||
phoneme_language=config.phoneme_language,
|
||||
enable_eos_bos=config.enable_eos_bos_chars,
|
||||
use_noise_augment=not is_eval,
|
||||
verbose=verbose,
|
||||
speaker_id_mapping=speaker_id_mapping,
|
||||
d_vector_mapping=d_vector_mapping
|
||||
if config.use_speaker_embedding and config.use_d_vector_file
|
||||
else None,
|
||||
)
|
||||
|
||||
if config.use_phonemes and config.compute_input_seq_cache:
|
||||
# precompute phonemes to have a better estimate of sequence lengths.
|
||||
dataset.compute_input_seq(config.num_loader_workers)
|
||||
dataset.sort_items()
|
||||
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
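A hedged usage sketch for the loader factory above; how `config`, the `AudioProcessor`, and the data items are obtained is assumed, not shown in this commit:

# Illustrative: `config`, `train_items` and `eval_items` are assumed to exist already.
ap = AudioProcessor(**config.audio)
train_loader = model.get_data_loader(config, ap, is_eval=False, data_items=train_items, verbose=True, num_gpus=1)
eval_loader = model.get_data_loader(config, ap, is_eval=True, data_items=eval_items, verbose=False, num_gpus=1)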
|
||||
|
||||
def test_run(self) -> Tuple[Dict, Dict]:
|
||||
"""Generic test run for `tts` models used by `Trainer`.
|
||||
|
||||
You can override this for a different behaviour.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Test figures and audios to be logged to Tensorboard.
|
||||
"""
|
||||
print(" | > Synthesizing test sentences.")
|
||||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self._get_aux_inputs()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
wav, alignment, model_outputs, _ = synthesis(
|
||||
self.model,
|
||||
sen,
|
||||
self.config,
|
||||
self.use_cuda,
|
||||
self.ap,
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
).values()
|
||||
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
|
||||
return test_figures, test_audios
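For context, `config.test_sentences` read above is just a list of plain strings, e.g. (sentences are illustrative):

config.test_sentences = [
    "It took me quite a long time to develop a voice.",
    "This cake is great. It's so delicious and moist.",
]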
|
|
@@ -4,131 +4,89 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.configs import GlowTTSConfig
|
||||
from TTS.tts.layers.glow_tts.decoder import Decoder
|
||||
from TTS.tts.layers.glow_tts.encoder import Encoder
|
||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
||||
from TTS.tts.models.abstract_tts import TTSModel
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class GlowTTS(TTSModel):
|
||||
class GlowTTS(BaseTTS):
|
||||
"""Glow TTS models from https://arxiv.org/abs/2005.11129
|
||||
|
||||
Args:
|
||||
num_chars (int): number of embedding characters.
|
||||
hidden_channels_enc (int): number of embedding and encoder channels.
|
||||
hidden_channels_dec (int): number of decoder channels.
|
||||
use_encoder_prenet (bool): enable/disable prenet for encoder. Prenet modules are hard-coded for each alternative encoder.
|
||||
hidden_channels_dp (int): number of duration predictor channels.
|
||||
out_channels (int): number of output channels. It should be equal to the number of spectrogram filters.
|
||||
num_flow_blocks_dec (int): number of decoder blocks.
|
||||
kernel_size_dec (int): decoder kernel size.
|
||||
dilation_rate (int): rate to increase dilation by each layer in a decoder block.
|
||||
num_block_layers (int): number of decoder layers in each decoder block.
|
||||
dropout_p_dec (float): dropout rate for decoder.
|
||||
num_speakers (int): number of speakers used to define the size of the speaker embedding layer.
|
||||
c_in_channels (int): number of speaker embedding channels. It is set to 512 if embeddings are learned.
|
||||
num_splits (int): number of split levels in invertible conv1x1 operation.
num_squeeze (int): number of squeeze levels. When squeezing, the number of channels increases and the number of time steps is reduced by the factor 'num_squeeze'.
|
||||
sigmoid_scale (bool): enable/disable sigmoid scaling in decoder.
|
||||
mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step.
|
||||
encoder_type (str): encoder module type.
|
||||
encoder_params (dict): encoder module parameters.
|
||||
d_vector_dim (int): channels of external speaker embedding vectors.
|
||||
Paper abstract:
|
||||
Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate
|
||||
mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained
|
||||
without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS,
|
||||
a flow-based generative model for parallel TTS that does not require any external aligner. By combining the
|
||||
properties of flows and dynamic programming, the proposed model searches for the most probable monotonic
|
||||
alignment between text and the latent representation of speech on its own. We demonstrate that enforcing hard
|
||||
monotonic alignments enables robust TTS, which generalizes to long utterances, and employing generative flows
|
||||
enables fast, diverse, and controllable speech synthesis. Glow-TTS obtains an order-of-magnitude speed-up over
|
||||
the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our
|
||||
model can be easily extended to a multi-speaker setting.
|
||||
|
||||
Check `GlowTTSConfig` for class arguments.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
hidden_channels_enc,
|
||||
hidden_channels_dec,
|
||||
use_encoder_prenet,
|
||||
hidden_channels_dp,
|
||||
out_channels,
|
||||
num_flow_blocks_dec=12,
|
||||
inference_noise_scale=0.33,
|
||||
kernel_size_dec=5,
|
||||
dilation_rate=5,
|
||||
num_block_layers=4,
|
||||
dropout_p_dp=0.1,
|
||||
dropout_p_dec=0.05,
|
||||
num_speakers=0,
|
||||
c_in_channels=0,
|
||||
num_splits=4,
|
||||
num_squeeze=1,
|
||||
sigmoid_scale=False,
|
||||
mean_only=False,
|
||||
encoder_type="transformer",
|
||||
encoder_params=None,
|
||||
d_vector_dim=None,
|
||||
):
|
||||
def __init__(self, config: GlowTTSConfig):
|
||||
|
||||
super().__init__()
|
||||
self.num_chars = num_chars
|
||||
self.hidden_channels_dp = hidden_channels_dp
|
||||
self.hidden_channels_enc = hidden_channels_enc
|
||||
self.hidden_channels_dec = hidden_channels_dec
|
||||
self.out_channels = out_channels
|
||||
self.num_flow_blocks_dec = num_flow_blocks_dec
|
||||
self.kernel_size_dec = kernel_size_dec
|
||||
self.dilation_rate = dilation_rate
|
||||
self.num_block_layers = num_block_layers
|
||||
self.dropout_p_dec = dropout_p_dec
|
||||
self.num_speakers = num_speakers
|
||||
self.c_in_channels = c_in_channels
|
||||
self.num_splits = num_splits
|
||||
self.num_squeeze = num_squeeze
|
||||
self.sigmoid_scale = sigmoid_scale
|
||||
self.mean_only = mean_only
|
||||
self.use_encoder_prenet = use_encoder_prenet
|
||||
self.inference_noise_scale = inference_noise_scale
|
||||
|
||||
# model constants.
|
||||
self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference.
|
||||
self.length_scale = 1.0 # scaler for the duration predictor. The larger it is, the slower the speech.
|
||||
self.d_vector_dim = d_vector_dim
|
||||
chars, self.config = self.get_characters(config)
|
||||
self.num_chars = len(chars)
|
||||
self.decoder_output_dim = config.out_channels
|
||||
self.init_multispeaker(config)
|
||||
|
||||
# pass all config fields to `self`
|
||||
# for fewer code changes
|
||||
self.config = config
|
||||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
# if the model is multi-speaker and c_in_channels is 0, set it to 256
|
||||
if num_speakers > 1:
|
||||
if self.c_in_channels == 0 and not self.d_vector_dim:
|
||||
self.c_in_channels = 0
|
||||
if self.num_speakers > 1:
|
||||
if self.d_vector_dim:
|
||||
self.c_in_channels = self.d_vector_dim
|
||||
elif self.c_in_channels == 0 and not self.d_vector_dim:
|
||||
# TODO: make this adjustable
|
||||
self.c_in_channels = 256
|
||||
elif self.d_vector_dim:
|
||||
self.c_in_channels = self.d_vector_dim
|
||||
|
||||
self.encoder = Encoder(
|
||||
num_chars,
|
||||
out_channels=out_channels,
|
||||
hidden_channels=hidden_channels_enc,
|
||||
hidden_channels_dp=hidden_channels_dp,
|
||||
encoder_type=encoder_type,
|
||||
encoder_params=encoder_params,
|
||||
mean_only=mean_only,
|
||||
use_prenet=use_encoder_prenet,
|
||||
dropout_p_dp=dropout_p_dp,
|
||||
self.num_chars,
|
||||
out_channels=self.out_channels,
|
||||
hidden_channels=self.hidden_channels_enc,
|
||||
hidden_channels_dp=self.hidden_channels_dp,
|
||||
encoder_type=self.encoder_type,
|
||||
encoder_params=self.encoder_params,
|
||||
mean_only=self.mean_only,
|
||||
use_prenet=self.use_encoder_prenet,
|
||||
dropout_p_dp=self.dropout_p_dp,
|
||||
c_in_channels=self.c_in_channels,
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
out_channels,
|
||||
hidden_channels_dec,
|
||||
kernel_size_dec,
|
||||
dilation_rate,
|
||||
num_flow_blocks_dec,
|
||||
num_block_layers,
|
||||
dropout_p=dropout_p_dec,
|
||||
num_splits=num_splits,
|
||||
num_squeeze=num_squeeze,
|
||||
sigmoid_scale=sigmoid_scale,
|
||||
self.out_channels,
|
||||
self.hidden_channels_dec,
|
||||
self.kernel_size_dec,
|
||||
self.dilation_rate,
|
||||
self.num_flow_blocks_dec,
|
||||
self.num_block_layers,
|
||||
dropout_p=self.dropout_p_dec,
|
||||
num_splits=self.num_splits,
|
||||
num_squeeze=self.num_squeeze,
|
||||
sigmoid_scale=self.sigmoid_scale,
|
||||
c_in_channels=self.c_in_channels,
|
||||
)
|
||||
|
||||
if num_speakers > 1 and not d_vector_dim:
|
||||
if self.num_speakers > 1 and not self.d_vector_dim:
|
||||
# speaker embedding layer
|
||||
self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
|
||||
self.emb_g = nn.Embedding(self.num_speakers, self.c_in_channels)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
|
||||
@staticmethod
|
||||
|
@@ -377,7 +335,7 @@ class GlowTTS(TTSModel):
|
|||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||
return figures, train_audio
|
||||
return figures, {"audio": train_audio}
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
@@ -406,3 +364,8 @@ class GlowTTS(TTSModel):
|
|||
self.eval()
|
||||
self.store_inverse()
|
||||
assert not self.training
|
||||
|
||||
def get_criterion(self):
|
||||
from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel
|
||||
|
||||
return GlowTTSLoss()
|
||||
|
|
|
@@ -1,4 +1,7 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.feed_forward.decoder import Decoder
|
||||
|
@@ -6,25 +9,16 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
|
|||
from TTS.tts.layers.feed_forward.encoder import Encoder
|
||||
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path
|
||||
from TTS.tts.models.abstract_tts import TTSModel
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class SpeedySpeech(TTSModel):
|
||||
"""Speedy Speech model
|
||||
https://arxiv.org/abs/2008.03802
|
||||
|
||||
Encoder -> DurationPredictor -> Decoder
|
||||
|
||||
This model is able to achieve a reasonable performance with only
|
||||
~3M model parameters and convolutional layers.
|
||||
|
||||
This model requires precomputed phoneme durations to train a duration predictor. At inference
|
||||
it only uses the duration predictor to compute durations and expand the encoder outputs accordingly.
|
||||
|
||||
@dataclass
|
||||
class SpeedySpeechArgs(Coqpit):
|
||||
"""
|
||||
Args:
|
||||
num_chars (int): number of unique input to characters
|
||||
out_channels (int): number of output tensor channels. It is equal to the expected spectrogram size.
|
||||
|
@@ -36,49 +30,107 @@ class SpeedySpeech(TTSModel):
|
|||
decoder_type (str, optional): decoder type. Defaults to 'residual_conv_bn'.
|
||||
decoder_params (dict, optional): set decoder parameters depending on 'decoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4, 8] + [1], "num_conv_blocks": 2, "num_res_blocks": 17 }.
|
||||
num_speakers (int, optional): number of speakers for multi-speaker training. Defaults to 0.
|
||||
external_c (bool, optional): enable external speaker embeddings. Defaults to False.
|
||||
c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
|
||||
use_d_vector (bool, optional): enable external speaker embeddings. Defaults to False.
|
||||
d_vector_dim (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
positional_encoding=True,
|
||||
length_scale=1,
|
||||
encoder_type="residual_conv_bn",
|
||||
encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
|
||||
decoder_type="residual_conv_bn",
|
||||
decoder_params={
|
||||
num_chars: int = None
|
||||
out_channels: int = 80
|
||||
hidden_channels: int = 128
|
||||
num_speakers: int = 0
|
||||
positional_encoding: bool = True
|
||||
length_scale: int = 1
|
||||
encoder_type: str = "residual_conv_bn"
|
||||
encoder_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"kernel_size": 4,
|
||||
"dilations": 4 * [1, 2, 4] + [1],
|
||||
"num_conv_blocks": 2,
|
||||
"num_res_blocks": 13,
|
||||
}
|
||||
)
|
||||
decoder_type: str = "residual_conv_bn"
|
||||
decoder_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"kernel_size": 4,
|
||||
"dilations": 4 * [1, 2, 4, 8] + [1],
|
||||
"num_conv_blocks": 2,
|
||||
"num_res_blocks": 17,
|
||||
},
|
||||
num_speakers=0,
|
||||
external_c=False,
|
||||
c_in_channels=0,
|
||||
):
|
||||
}
|
||||
)
|
||||
use_d_vector: bool = False
|
||||
d_vector_dim: int = 0
|
||||
|
||||
|
||||
class SpeedySpeech(BaseTTS):
|
||||
"""Speedy Speech model
|
||||
https://arxiv.org/abs/2008.03802
|
||||
|
||||
Encoder -> DurationPredictor -> Decoder
|
||||
|
||||
Paper abstract:
|
||||
While recent neural sequence-to-sequence models have greatly improved the quality of speech
|
||||
synthesis, there has not been a system capable of fast training, fast inference and high-quality audio synthesis
|
||||
at the same time. We propose a student-teacher network capable of high-quality faster-than-real-time spectrogram
|
||||
synthesis, with low requirements on computational resources and fast training time. We show that self-attention
|
||||
layers are not necessary for generation of high quality audio. We utilize simple convolutional blocks with
|
||||
residual connections in both student and teacher networks and use only a single attention layer in the teacher
|
||||
model. Coupled with a MelGAN vocoder, our model's voice quality was rated significantly higher than Tacotron 2.
|
||||
Our model can be efficiently trained on a single GPU and can run in real time even on a CPU. We provide both
|
||||
our source code and audio samples in our GitHub repository.
|
||||
|
||||
Notes:
|
||||
The vanilla model is able to achieve a reasonable performance with only
|
||||
~3M model parameters and convolutional layers.
|
||||
|
||||
This model requires precomputed phoneme durations to train a duration predictor. At inference
|
||||
it only uses the duration predictor to compute durations and expand the encoder outputs accordingly.
|
||||
|
||||
You can also mix and match different encoder and decoder networks beyond the paper.
|
||||
|
||||
Check `SpeedySpeechArgs` for arguments.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__()
|
||||
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
|
||||
self.emb = nn.Embedding(num_chars, hidden_channels)
|
||||
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
|
||||
if positional_encoding:
|
||||
self.pos_encoder = PositionalEncoding(hidden_channels)
|
||||
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
|
||||
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
|
||||
self.config = config
|
||||
|
||||
if num_speakers > 1 and not external_c:
|
||||
if "characters" in config:
|
||||
chars, self.config = self.get_characters(config)
|
||||
self.num_chars = len(chars)
|
||||
|
||||
self.length_scale = (
|
||||
float(config.model_args.length_scale)
|
||||
if isinstance(config.model_args.length_scale, int)
|
||||
else config.model_args.length_scale
|
||||
)
|
||||
self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
|
||||
self.encoder = Encoder(
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.encoder_type,
|
||||
config.model_args.encoder_params,
|
||||
config.model_args.d_vector_dim,
|
||||
)
|
||||
if config.model_args.positional_encoding:
|
||||
self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels)
|
||||
self.decoder = Decoder(
|
||||
config.model_args.out_channels,
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.decoder_type,
|
||||
config.model_args.decoder_params,
|
||||
)
|
||||
self.duration_predictor = DurationPredictor(config.model_args.hidden_channels + config.model_args.d_vector_dim)
|
||||
|
||||
if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector:
|
||||
# speaker embedding layer
|
||||
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
|
||||
self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
|
||||
if c_in_channels > 0 and c_in_channels != hidden_channels:
|
||||
self.proj_g = nn.Conv1d(c_in_channels, hidden_channels, 1)
|
||||
if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels:
|
||||
self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1)
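A construction sketch for the new config-based signature above, mirroring the AlignTTS example earlier in this commit; it assumes `SpeedySpeechConfig` exposes the same `model_args` layout:

from TTS.tts.configs import SpeedySpeechConfig

config = SpeedySpeechConfig()
config.model_args.num_chars = 50  # assumed field, see SpeedySpeechArgs above
model = SpeedySpeech(config)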
|
||||
|
||||
@staticmethod
|
||||
def expand_encoder_outputs(en, dr, x_mask, y_mask):
|
||||
|
@@ -244,7 +296,7 @@ class SpeedySpeech(TTSModel):
|
|||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||
return figures, train_audio
|
||||
return figures, {"audio": train_audio}
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
@@ -260,3 +312,8 @@ class SpeedySpeech(TTSModel):
|
|||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
def get_criterion(self):
|
||||
from TTS.tts.layers.losses import SpeedySpeechLoss # pylint: disable=import-outside-toplevel
|
||||
|
||||
return SpeedySpeechLoss(self.config)
|
||||
|
|
|
@@ -1,166 +1,86 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||
from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
|
||||
from TTS.tts.models.tacotron_abstract import TacotronAbstract
|
||||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class Tacotron(TacotronAbstract):
|
||||
class Tacotron(BaseTacotron):
|
||||
"""Tacotron as in https://arxiv.org/abs/1703.10135
|
||||
|
||||
It's an autoregressive encoder-attention-decoder-postnet architecture.
|
||||
|
||||
Args:
|
||||
num_chars (int): number of input characters to define the size of embedding layer.
|
||||
num_speakers (int): number of speakers in the dataset. >1 enables multi-speaker training and model learns speaker embeddings.
|
||||
r (int): initial model reduction rate.
|
||||
postnet_output_dim (int, optional): postnet output channels. Defaults to 80.
|
||||
decoder_output_dim (int, optional): decoder output channels. Defaults to 80.
|
||||
attn_type (str, optional): attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'.
|
||||
attn_win (bool, optional): enable/disable attention windowing.
|
||||
It is especially useful at inference to keep the attention alignment diagonal. Defaults to False.
|
||||
attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
|
||||
prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
|
||||
prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
|
||||
prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to a better quality for
|
||||
some models. Defaults to False.
|
||||
forward_attn (bool, optional): enable/disable forward attention.
|
||||
It is only valid if ```attn_type``` is ```original```. Defaults to False.
|
||||
trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
|
||||
forward_attn_mask (bool, optional): enable/disable extra masking over forward attention. Defaults to False.
|
||||
location_attn (bool, optional): enable/disable location sensitive attention.
|
||||
It is only valid if ```attn_type``` is ```original```. Defaults to True.
|
||||
attn_K (int, optional): Number of attention heads for GMM attention. Defaults to 5.
|
||||
separate_stopnet (bool, optional): enable/disable training the stopnet separately so that its gradient does not
flow back to the rest of the model. Defaults to True.
|
||||
bidirectional_decoder (bool, optional): enable/disable bidirectional decoding. Defaults to False.
|
||||
double_decoder_consistency (bool, optional): enable/disable double decoder consistency. Defaults to False.
|
||||
ddc_r (int, optional): reduction rate for the coarse decoder of double decoder consistency. Defaults to None.
|
||||
encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
|
||||
decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
|
||||
d_vector_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
|
||||
use_gst (bool, optional): enable/disable Global style token module.
|
||||
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
|
||||
memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size```
|
||||
output frames to the prenet.
|
||||
gradual_training (List): Gradual training schedule. If None or `[]`, no gradual training is used.
|
||||
Defaults to `[]`.
|
||||
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 10000.
|
||||
Check `TacotronConfig` for the arguments.
|
||||
"""
|
||||
|
||||
def __init__(
self,
num_chars,
num_speakers,
r=5,
postnet_output_dim=1025,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="sigmoid",
prenet_type="original",
prenet_dropout=True,
prenet_dropout_at_inference=False,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=256,
decoder_in_features=256,
d_vector_dim=None,
use_gst=False,
gst=None,
memory_size=5,
gradual_training=None,
max_decoder_steps=500,
):
super().__init__(
num_chars,
num_speakers,
r,
postnet_output_dim,
decoder_output_dim,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
prenet_dropout_at_inference,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
bidirectional_decoder,
double_decoder_consistency,
ddc_r,
encoder_in_features,
decoder_in_features,
d_vector_dim,
use_gst,
gst,
gradual_training,
)
def __init__(self, config: Coqpit):
super().__init__(config)

# speaker embedding layers
self.num_chars, self.config = self.get_characters(config)

# pass all config fields to `self`
# for fewer code changes
for key in config:
setattr(self, key, config[key])

# speaker embedding layer
if self.num_speakers > 1:
if not self.use_d_vectors:
d_vector_dim = 256
self.speaker_embedding = nn.Embedding(self.num_speakers, d_vector_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.init_multispeaker(config)

# speaker and gst embeddings are concatenated into the decoder input
if self.num_speakers > 1:
self.decoder_in_features += d_vector_dim  # add speaker embedding dim
self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

if self.use_gst:
self.decoder_in_features += self.gst.gst_embedding_dim

# embedding layer
self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
self.embedding = nn.Embedding(self.num_chars, 256, padding_idx=0)
self.embedding.weight.data.normal_(0, 0.3)

# base model layers
self.encoder = Encoder(self.encoder_in_features)
self.decoder = Decoder(
self.decoder_in_features,
decoder_output_dim,
r,
memory_size,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
max_decoder_steps,
self.decoder_output_dim,
self.r,
self.memory_size,
self.attention_type,
self.windowing,
self.attention_norm,
self.prenet_type,
self.prenet_dropout,
self.use_forward_attn,
self.transition_agent,
self.forward_attn_mask,
self.location_attn,
self.attention_heads,
self.separate_stopnet,
self.max_decoder_steps,
)
self.postnet = PostCBHG(decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
self.postnet = PostCBHG(self.decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, self.out_channels)

# setup prenet dropout
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference

# global style token layers
if self.gst and self.use_gst:
self.gst_layer = GST(
num_mel=decoder_output_dim,
d_vector_dim=d_vector_dim,
num_heads=gst.gst_num_heads,
num_style_tokens=gst.gst_num_style_tokens,
gst_embedding_dim=gst.gst_embedding_dim,
num_mel=self.decoder_output_dim,
d_vector_dim=self.d_vector_dim
if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
else None,
num_heads=self.gst.gst_num_heads,
num_style_tokens=self.gst.gst_num_style_tokens,
gst_embedding_dim=self.gst.gst_embedding_dim,
)
# backward pass decoder
if self.bidirectional_decoder:

@ -169,21 +89,21 @@ class Tacotron(TacotronAbstract):
if self.double_decoder_consistency:
self.coarse_decoder = Decoder(
self.decoder_in_features,
decoder_output_dim,
ddc_r,
memory_size,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
max_decoder_steps,
self.decoder_output_dim,
self.ddc_r,
self.memory_size,
self.attention_type,
self.windowing,
self.attention_norm,
self.prenet_type,
self.prenet_dropout,
self.use_forward_attn,
self.transition_agent,
self.forward_attn_mask,
self.location_attn,
self.attention_heads,
self.separate_stopnet,
self.max_decoder_steps,
)

def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):

@ -205,7 +125,9 @@ class Tacotron(TacotronAbstract):
# global style token
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, aux_input["d_vectors"])
encoder_outputs = self.compute_gst(
encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
)
# speaker embedding
if self.num_speakers > 1:
if not self.use_d_vectors:

@ -341,7 +263,7 @@ class Tacotron(TacotronAbstract):
loss_dict["align_error"] = align_error
return outputs, loss_dict

def train_log(self, ap, batch, outputs):
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict) -> Tuple[Dict, Dict]:
postnet_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
alignments_backward = outputs["alignments_backward"]

@ -362,7 +284,7 @@ class Tacotron(TacotronAbstract):

# Sample audio
train_audio = ap.inv_spectrogram(pred_spec.T)
return figures, train_audio
return figures, {"audio": train_audio}

def eval_step(self, batch, criterion):
return self.train_step(batch, criterion)

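For context, a minimal sketch of how the refactored, config-driven constructor might be used. `TacotronConfig` is referenced in the docstring above, but the exact import path and field names here are assumptions, not part of this diff:

# hypothetical usage sketch; field names mirror the docstring, import paths are assumed
from TTS.tts.configs import TacotronConfig
from TTS.tts.models.tacotron import Tacotron

config = TacotronConfig(
    num_speakers=1,        # single speaker: no speaker embedding layer is created
    use_gst=False,         # the GST layer is only built when `use_gst` and `gst` are set
    r=5,                   # initial reduction rate
    memory_size=5,         # prenet history queue size
    max_decoder_steps=500,
)
model = Tacotron(config)   # every config field is copied onto `self` in __init__
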
@ -1,160 +1,84 @@
# coding: utf-8

from typing import Dict, Tuple

import torch
from coqpit import Coqpit
from torch import nn

from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.models.tacotron_abstract import TacotronAbstract
from TTS.tts.models.base_tacotron import BaseTacotron
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor


class Tacotron2(TacotronAbstract):
class Tacotron2(BaseTacotron):
"""Tacotron2 as in https://arxiv.org/abs/1712.05884

It's an autoregressive encoder-attention-decoder-postnet architecture.

Args:
num_chars (int): number of input characters to define the size of the embedding layer.
num_speakers (int): number of speakers in the dataset. >1 enables multi-speaker training and the model learns speaker embeddings.
r (int): initial model reduction rate.
postnet_output_dim (int, optional): postnet output channels. Defaults to 80.
decoder_output_dim (int, optional): decoder output channels. Defaults to 80.
attn_type (str, optional): attention type. Check ```TTS.tts.layers.tacotron.common_layers.init_attn```. Defaults to 'original'.
attn_win (bool, optional): enable/disable attention windowing.
It is especially useful at inference to keep the attention alignment diagonal. Defaults to False.
attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to better quality for
some models. Defaults to False.
forward_attn (bool, optional): enable/disable forward attention.
It is only valid if ```attn_type``` is ```original```. Defaults to False.
trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
forward_attn_mask (bool, optional): enable/disable extra masking over forward attention. Defaults to False.
location_attn (bool, optional): enable/disable location sensitive attention.
It is only valid if ```attn_type``` is ```original```. Defaults to True.
attn_K (int, optional): Number of attention heads for GMM attention. Defaults to 5.
separate_stopnet (bool, optional): enable/disable training the stopnet separately so that stopnet gradients
do not flow back into the rest of the model. Defaults to True.
bidirectional_decoder (bool, optional): enable/disable bidirectional decoding. Defaults to False.
double_decoder_consistency (bool, optional): enable/disable double decoder consistency. Defaults to False.
ddc_r (int, optional): reduction rate for the coarse decoder of double decoder consistency. Defaults to None.
encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
d_vector_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
use_gst (bool, optional): enable/disable the Global Style Token module.
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
gradual_training (List): Gradual training schedule. If None or `[]`, no gradual training is used.
Defaults to `[]`.
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 10000.
Check `TacotronConfig` for the arguments.
"""

def __init__(
self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
prenet_dropout_at_inference=False,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=512,
decoder_in_features=512,
d_vector_dim=None,
use_gst=False,
gst=None,
gradual_training=None,
max_decoder_steps=500,
):
super().__init__(
num_chars,
num_speakers,
r,
postnet_output_dim,
decoder_output_dim,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
prenet_dropout_at_inference,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
bidirectional_decoder,
double_decoder_consistency,
ddc_r,
encoder_in_features,
decoder_in_features,
d_vector_dim,
use_gst,
gst,
gradual_training,
)
def __init__(self, config: Coqpit):
super().__init__(config)

chars, self.config = self.get_characters(config)
self.num_chars = len(chars)
self.decoder_output_dim = config.out_channels

# pass all config fields to `self`
# for fewer code changes
for key in config:
setattr(self, key, config[key])

# speaker embedding layer
if self.num_speakers > 1:
if not self.use_d_vectors:
d_vector_dim = 512
self.speaker_embedding = nn.Embedding(self.num_speakers, d_vector_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.init_multispeaker(config)

# speaker and gst embeddings are concatenated into the decoder input
if self.num_speakers > 1:
self.decoder_in_features += d_vector_dim  # add speaker embedding dim
self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

if self.use_gst:
self.decoder_in_features += self.gst.gst_embedding_dim

# embedding layer
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0)

# base model layers
self.encoder = Encoder(self.encoder_in_features)
self.decoder = Decoder(
self.decoder_in_features,
self.decoder_output_dim,
r,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
max_decoder_steps,
self.r,
self.attention_type,
self.attention_win,
self.attention_norm,
self.prenet_type,
self.prenet_dropout,
self.use_forward_attn,
self.transition_agent,
self.forward_attn_mask,
self.location_attn,
self.attention_heads,
self.separate_stopnet,
self.max_decoder_steps,
)
self.postnet = Postnet(self.postnet_output_dim)
self.postnet = Postnet(self.out_channels)

# setup prenet dropout
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference

# global style token layers
if self.gst and use_gst:
if self.gst and self.use_gst:
self.gst_layer = GST(
num_mel=decoder_output_dim,
d_vector_dim=d_vector_dim,
num_heads=gst.gst_num_heads,
num_style_tokens=gst.gst_num_style_tokens,
gst_embedding_dim=gst.gst_embedding_dim,
num_mel=self.decoder_output_dim,
d_vector_dim=self.d_vector_dim
if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
else None,
num_heads=self.gst.gst_num_heads,
num_style_tokens=self.gst.gst_num_style_tokens,
gst_embedding_dim=self.gst.gst_embedding_dim,
)

# backward pass decoder

@ -165,19 +89,19 @@ class Tacotron2(TacotronAbstract):
self.coarse_decoder = Decoder(
self.decoder_in_features,
self.decoder_output_dim,
ddc_r,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
max_decoder_steps,
self.ddc_r,
self.attention_type,
self.attention_win,
self.attention_norm,
self.prenet_type,
self.prenet_dropout,
self.use_forward_attn,
self.transition_agent,
self.forward_attn_mask,
self.location_attn,
self.attention_heads,
self.separate_stopnet,
self.max_decoder_steps,
)

@staticmethod

@ -206,7 +130,9 @@ class Tacotron2(TacotronAbstract):
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, aux_input["d_vectors"])
encoder_outputs = self.compute_gst(
encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
)
if self.num_speakers > 1:
if not self.use_d_vectors:
# B x 1 x speaker_embed_dim

@ -342,7 +268,7 @@ class Tacotron2(TacotronAbstract):
loss_dict["align_error"] = align_error
return outputs, loss_dict

def train_log(self, ap, batch, outputs):
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict) -> Tuple[Dict, Dict]:
postnet_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
alignments_backward = outputs["alignments_backward"]

@ -363,7 +289,7 @@ class Tacotron2(TacotronAbstract):

# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)
return figures, train_audio
return figures, {"audio": train_audio}

def eval_step(self, batch, criterion):
return self.train_step(batch, criterion)

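Both `train_log()` implementations above now return the sample waveform wrapped in a dict (`{"audio": train_audio}`) instead of a bare array. A rough sketch of how a training loop might consume that, assuming a generic TensorBoard-style logger (the logger object and its methods are hypothetical stand-ins, not part of this diff):

# sketch only; `tb_logger` and `global_step` are assumed to exist in the training loop
figures, audios = model.train_log(ap, batch, outputs)
for name, figure in figures.items():
    tb_logger.add_figure(name, figure, global_step)
tb_logger.add_audio("train_audio", audios["audio"], global_step, sample_rate=ap.sample_rate)
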
@ -12,7 +12,7 @@ class Tacotron2(keras.models.Model):
num_chars,
num_speakers,
r,
postnet_output_dim=80,
out_channels=80,
decoder_output_dim=80,
attn_type="original",
attn_win=False,

@ -31,7 +31,7 @@ class Tacotron2(keras.models.Model):
super().__init__()
self.r = r
self.decoder_output_dim = decoder_output_dim
self.postnet_output_dim = postnet_output_dim
self.out_channels = out_channels
self.bidirectional_decoder = bidirectional_decoder
self.num_speakers = num_speakers
self.speaker_embed_dim = 256

@ -58,7 +58,7 @@ class Tacotron2(keras.models.Model):
name="decoder",
enable_tflite=enable_tflite,
)
self.postnet = Postnet(postnet_output_dim, 5, name="postnet")
self.postnet = Postnet(out_channels, 5, name="postnet")

@tf.function(experimental_relax_shapes=True)
def call(self, characters, text_lengths=None, frames=None, training=None):

@ -0,0 +1,20 @@
from TTS.model import BaseModel

# pylint: skip-file


class BaseVocoder(BaseModel):
"""Base `vocoder` class. Every new `vocoder` model must inherit this.

It defines `vocoder` specific functions on top of `Model`.

Notes on input/output tensor shapes:
Any input or output tensor of the model must be shaped as

- 3D tensors `batch x time x channels`
- 2D tensors `batch x channels`
- 1D tensors `batch x 1`
"""

def __init__(self):
super().__init__()
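A minimal sketch of what a concrete subclass of the new `BaseVocoder` could look like, following the `train_step`/`eval_step`/`load_checkpoint` contract inherited from `BaseModel`. The import path, layer sizes, and checkpoint layout are placeholders, not part of this commit:

# hypothetical subclass sketch; shapes follow the batch x time x channels convention noted above
import torch
from torch import nn
from TTS.vocoder.models.base_vocoder import BaseVocoder  # import path assumed


class MyVocoder(BaseVocoder):
    def __init__(self):
        super().__init__()
        self.net = nn.Conv1d(80, 1, kernel_size=1)  # placeholder layer: mel frames -> waveform-like output

    def forward(self, x):
        # x: batch x time x channels -> batch x time x 1
        return self.net(x.transpose(1, 2)).transpose(1, 2)

    def train_step(self, batch, criterion):
        y_hat = self.forward(batch["input"])
        loss = criterion(y_hat, batch["waveform"])
        return {"model_outputs": y_hat}, {"loss": loss}

    def eval_step(self, batch, criterion):
        return self.train_step(batch, criterion)

    def load_checkpoint(self, config, checkpoint_path, eval=False):
        state = torch.load(checkpoint_path, map_location="cpu")  # checkpoint key layout assumed
        self.load_state_dict(state["model"])
        if eval:
            self.eval()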