# coqui-tts/TTS/tts/models/base_tacotron.py
#
# 286 lines
# 11 KiB
# Python

import copy
from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, List
import torch
from coqpit import MISSING, Coqpit
from torch import nn
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.text import make_symbols
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
@dataclass
class BaseTacotronArgs(Coqpit):
    """Argument set shared by Tacotron-style models.

    TODO: update Tacotron configs using it
    """

    # Required fields: `MISSING` marks them as having no default value.
    num_chars: int = MISSING
    num_speakers: int = MISSING
    r: int = MISSING

    # Output dimensions.
    out_channels: int = 80
    decoder_output_dim: int = 80

    # Attention options.
    attn_type: str = "original"
    attn_win: bool = False
    attn_norm: str = "softmax"

    # Prenet options.
    prenet_type: str = "original"
    prenet_dropout: bool = True
    prenet_dropout_at_inference: bool = False

    # More attention options.
    forward_attn: bool = False
    trans_agent: bool = False
    forward_attn_mask: bool = False
    location_attn: bool = True
    attn_K: int = 5

    # Decoder variants.
    separate_stopnet: bool = True
    bidirectional_decoder: bool = False
    double_decoder_consistency: bool = False
    ddc_r: int = None

    # Feature sizes.
    encoder_in_features: int = 512
    decoder_in_features: int = 512

    # Speaker / style conditioning.
    d_vector_dim: int = None
    use_gst: bool = False
    # NOTE(review): `BaseTacotron` reads `gst.gst_embedding_dim`, so `bool`
    # looks like the wrong annotation here -- confirm before relying on it.
    gst: bool = None

    # NOTE(review): annotated `bool` but defaults to None -- confirm the
    # intended type against the training scheduler that consumes it.
    gradual_training: bool = None
class BaseTacotron(BaseTTS):
    """Abstract base class with the logic shared by Tacotron-style models.

    The constructor only declares placeholders (``None``) for the layers;
    subclasses are expected to create them and to implement :meth:`forward`
    and :meth:`inference`.
    """

    def __init__(self, config: Coqpit):
        """Copy the config fields onto the model and declare layer placeholders.

        Args:
            config (Coqpit): Model configuration. Every key of ``config``
                becomes an attribute of the same name on ``self``.
        """
        super().__init__()

        # `get_criterion` and `get_aux_input` read `self.config`, so keep a
        # reference to the config object next to the flattened fields.
        self.config = config
        for key in config:
            setattr(self, key, config[key])

        # layers
        self.embedding = None
        self.encoder = None
        self.decoder = None
        self.postnet = None

        # init tensors
        self.embedded_speakers = None
        self.embedded_speakers_projected = None

        # global style token
        if self.gst and self.use_gst:
            self.decoder_in_features += self.gst.gst_embedding_dim  # add gst embedding dim
            self.gst_layer = None

        # additional layers
        self.decoder_backward = None
        self.coarse_decoder = None

    @staticmethod
    def _format_aux_input(aux_input: Dict) -> Dict:
        """Pass ``aux_input`` through ``format_aux_input`` with ``None`` defaults
        for the ``d_vectors`` and ``speaker_ids`` keys."""
        return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)

    #############################
    # INIT FUNCTIONS
    #############################

    def _init_states(self):
        """Reset the cached speaker embedding tensors to ``None``."""
        self.embedded_speakers = None
        self.embedded_speakers_projected = None

    def _init_backward_decoder(self):
        """Create the backward decoder as a deep copy of ``self.decoder``."""
        self.decoder_backward = copy.deepcopy(self.decoder)

    def _init_coarse_decoder(self):
        """Create the coarse decoder as a deep copy of ``self.decoder`` that
        runs with the reduction factor ``self.ddc_r``."""
        self.coarse_decoder = copy.deepcopy(self.decoder)
        self.coarse_decoder.r_init = self.ddc_r
        self.coarse_decoder.set_r(self.ddc_r)

    #############################
    # CORE FUNCTIONS
    #############################

    @abstractmethod
    def forward(self):
        """Training forward pass; must be implemented by subclasses."""

    @abstractmethod
    def inference(self):
        """Inference pass; must be implemented by subclasses."""

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load model weights and the decoder reduction factor from a checkpoint.

        Args:
            config: Unused; kept for interface compatibility.
            checkpoint_path: Location handed to ``load_fsspec``.
            eval (bool): When true, switch the model to evaluation mode after
                loading.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        # Prefer a top-level "r" entry; otherwise fall back to the value
        # stored in the checkpoint's config.
        r = state["r"] if "r" in state else state["config"]["r"]
        self.decoder.set_r(r)
        if eval:
            self.eval()
            assert not self.training

    def get_criterion(self) -> nn.Module:
        """Return a ``TacotronLoss`` built from ``self.config``."""
        return TacotronLoss(self.config)

    @staticmethod
    def get_characters(config: Coqpit) -> tuple:
        """Select the character set the model is built for.

        When ``config.characters`` is ``None`` the default symbol tables are
        imported and ``config.characters`` is filled in with the result of
        ``parse_symbols()`` (``config`` is mutated in place).

        Args:
            config (Coqpit): Config providing ``characters`` and
                ``use_phonemes``.

        Returns:
            tuple: ``(model_characters, config)`` where ``model_characters`` is
            the phoneme set if ``config.use_phonemes`` is truthy and the symbol
            set otherwise.
        """
        # TODO: implement CharacterProcessor
        if config.characters is not None:
            symbols, phonemes = make_symbols(**config.characters)
        else:
            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
                parse_symbols,
                phonemes,
                symbols,
            )

            config.characters = parse_symbols()
        model_characters = phonemes if config.use_phonemes else symbols
        return model_characters, config

    @staticmethod
    def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
        """Thin wrapper around the module-level ``get_speaker_manager`` helper."""
        return get_speaker_manager(config, restore_path, data, out_path)

    def get_aux_input(self, **kwargs) -> Dict:
        """Compute Tacotron's auxiliary inputs based on model config.

        - speaker d_vector
        - style wav for GST
        - speaker ID for speaker embedding

        Keyword Args:
            speaker_id: Used when ``config.use_speaker_embedding`` is set;
                defaults to ``0``.
            style_wav: Style input; when absent, ``config.gst_style_input`` is
                used if the config has that field.

        Returns:
            Dict: Dictionary with the keys ``speaker_id``, ``style_wav`` and
            ``d_vector``.
        """
        # setup speaker_id
        speaker_id = kwargs.get("speaker_id", 0) if self.config.use_speaker_embedding else None

        # setup d_vector: taken from the first speaker known to the speaker manager
        d_vector = (
            self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
            if self.config.use_d_vector_file and self.config.use_speaker_embedding
            else None
        )

        # setup style_mel
        if "style_wav" in kwargs:
            style_wav = kwargs["style_wav"]
        elif self.config.has("gst_style_input"):
            style_wav = self.config.gst_style_input
        else:
            style_wav = None

        if style_wav is None and "use_gst" in self.config and self.config.use_gst:
            # initialize GST with a dict that gives every style token a zero weight.
            print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
            style_wav = {str(i): 0 for i in range(self.config.gst["gst_num_style_tokens"])}

        return {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}

    #############################
    # COMMON COMPUTE FUNCTIONS
    #############################

    def compute_masks(self, text_lengths, mel_lengths):
        """Compute masks against sequence paddings.

        Args:
            text_lengths: Lengths passed to ``sequence_mask`` for the input mask.
            mel_lengths: Output lengths, or ``None`` to skip the output mask.

        Returns:
            tuple: ``(input_mask, output_mask)``; ``output_mask`` is ``None``
            when ``mel_lengths`` is ``None``.
        """
        # B x T_in_max (boolean)
        input_mask = sequence_mask(text_lengths)
        output_mask = None
        if mel_lengths is not None:
            max_len = mel_lengths.max()
            r = self.decoder.r
            # round the mask length up to the next multiple of the reduction factor
            remainder = max_len % r
            if remainder > 0:
                max_len = max_len + (r - remainder)
            output_mask = sequence_mask(mel_lengths, max_len=max_len)
        return input_mask, output_mask

    def _backward_pass(self, mel_specs, encoder_outputs, mask):
        """Run backwards decoder.

        The target spectrograms are flipped along dim 1 before being fed to
        ``self.decoder_backward``.

        Returns:
            tuple: ``(decoder_outputs_b, alignments_b)`` with the decoder
            outputs transposed on dims 1 and 2.
        """
        decoder_outputs_b, alignments_b, _ = self.decoder_backward(
            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
        )
        decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
        return decoder_outputs_b, alignments_b

    def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
        """Double Decoder Consistency.

        Runs ``self.coarse_decoder`` on detached encoder outputs and brings its
        results back to the size of the main decoder's results.

        Returns:
            tuple: ``(decoder_outputs, alignments)`` of the coarse decoder. The
            outputs are trimmed to the original length of ``mel_specs`` along
            dim 1; the alignments are resized to ``alignments.shape[1]``.
        """
        T = mel_specs.shape[1]
        coarse_r = self.coarse_decoder.r
        if T % coarse_r > 0:
            # pad the end of dim 1 so its length is a multiple of the coarse
            # reduction factor (pad pairs run from the last dim backwards)
            padding_size = coarse_r - (T % coarse_r)
            mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
        decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
            encoder_outputs.detach(), mel_specs, input_mask
        )
        # match the number of steps of the main decoder's alignments
        alignments_backward = torch.nn.functional.interpolate(
            alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest"
        ).transpose(1, 2)
        decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
        # drop the steps that only exist because of the padding above
        decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
        return decoder_outputs_backward, alignments_backward

    #############################
    # EMBEDDING FUNCTIONS
    #############################

    def compute_speaker_embedding(self, speaker_ids):
        """Compute speaker embedding vectors.

        Results are stored on ``self.embedded_speakers`` and, when the model
        has a ``speaker_project_mel`` layer, ``self.embedded_speakers_projected``.

        Raises:
            RuntimeError: If the model has a ``speaker_embedding`` layer but
                ``speaker_ids`` is ``None``.
        """
        if speaker_ids is None:
            if hasattr(self, "speaker_embedding"):
                raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
            return
        if hasattr(self, "speaker_embedding"):
            self.embedded_speakers = self.speaker_embedding(speaker_ids).unsqueeze(1)
        if hasattr(self, "speaker_project_mel"):
            self.embedded_speakers_projected = self.speaker_project_mel(self.embedded_speakers).squeeze(1)

    def compute_gst(self, inputs, style_input, speaker_embedding=None):
        """Compute global style token.

        Args:
            inputs: Tensor the style embedding is concatenated to.
            style_input: One of
                - ``dict`` mapping a style token index to a weight,
                - ``None`` to use an all-zero style embedding,
                - anything else, which is handed to ``self.gst_layer``.
            speaker_embedding: Optional speaker embedding used for conditioning.

        Returns:
            ``inputs`` with the style embedding concatenated on the last dim.
        """
        if isinstance(style_input, dict):
            # multiply each style token with a weight
            query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
            if speaker_embedding is not None:
                query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)

            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
            for k_token, v_amplifier in style_input.items():
                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
        elif style_input is None:
            # ignore style token and return zero tensor
            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
        else:
            # compute style tokens
            gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable
        inputs = self._concat_speaker_embedding(inputs, gst_outputs)
        return inputs

    @staticmethod
    def _add_speaker_embedding(outputs, embedded_speakers):
        """Add ``embedded_speakers``, expanded over dims 0 and 1 of ``outputs``,
        to ``outputs``."""
        embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
        return outputs + embedded_speakers_

    @staticmethod
    def _concat_speaker_embedding(outputs, embedded_speakers):
        """Concatenate ``embedded_speakers``, expanded over dims 0 and 1 of
        ``outputs``, to ``outputs`` along the last dim."""
        embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
        return torch.cat([outputs, embedded_speakers_], dim=-1)

    #############################
    # CALLBACKS
    #############################

    def on_epoch_start(self, trainer):
        """Callback for setting values wrt gradual training schedule.

        Does nothing unless ``self.gradual_training`` is truthy.

        Args:
            trainer (TrainerTTS): TTS trainer object that is used to train this model.
        """
        if self.gradual_training:
            r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
            trainer.config.r = r
            self.decoder.set_r(r)
            if trainer.config.bidirectional_decoder:
                # Go through `self` instead of `trainer.model`: the trainer may
                # hold a wrapper (e.g. a data-parallel module) that does not
                # expose sub-modules as attributes.
                # NOTE(review): assumes `trainer.model` is this model or wraps it.
                self.decoder_backward.set_r(r)
            print(f"\n > Number of output frames: {self.decoder.r}")