Comment BaseTacotron and remove unused funcs

Eren Gölge 2021-10-20 18:17:25 +00:00
parent aa25f70b95
commit 330ee7d208
1 changed file with 14 additions and 53 deletions


@@ -9,15 +9,15 @@ from torch import nn
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
class BaseTacotron(BaseTTS):
"""Base class shared by Tacotron and Tacotron2"""
def __init__(self, config: Coqpit):
"""Abstract Tacotron class"""
super().__init__(config)
# pass all config fields as class attributes
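
A minimal sketch of the "pass all config fields as class attributes" pattern noted above, using a plain dict as a stand-in for the Coqpit config; the field names and values are hypothetical, not the actual Tacotron config:

class ConfigAttrExample:
    """Hypothetical stand-in for the 'config fields become attributes' pattern."""

    def __init__(self, config: dict):
        for key, value in config.items():
            setattr(self, key, value)


cfg = {"r": 7, "ddc_r": 6, "use_gst": False}  # hypothetical field names/values
model = ConfigAttrExample(cfg)
assert model.r == 7
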
@@ -45,6 +45,7 @@ class BaseTacotron(BaseTTS):
@staticmethod
def _format_aux_input(aux_input: Dict) -> Dict:
"""Set missing fields to their default values"""
if aux_input:
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
return None
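
As a rough illustration of the defaulting behaviour above, a stand-alone sketch (not the library's `format_aux_input`, just an equivalent merge of defaults into a partial aux dict):

from typing import Dict, Optional


def fill_aux_defaults(aux_input: Optional[Dict]) -> Optional[Dict]:
    """Merge user-provided aux inputs over the expected defaults."""
    if not aux_input:
        return None
    defaults = {"d_vectors": None, "speaker_ids": None}
    return {**defaults, **aux_input}


print(fill_aux_defaults({"speaker_ids": 3}))  # {'d_vectors': None, 'speaker_ids': 3}
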
@@ -53,14 +54,12 @@ class BaseTacotron(BaseTTS):
# INIT FUNCTIONS
#############################
def _init_states(self):
self.embedded_speakers = None
self.embedded_speakers_projected = None
def _init_backward_decoder(self):
"""Init the backward decoder for Forward-Backward decoding."""
self.decoder_backward = copy.deepcopy(self.decoder)
def _init_coarse_decoder(self):
"""Init the coarse decoder for Double-Decoder Consistency."""
self.coarse_decoder = copy.deepcopy(self.decoder)
self.coarse_decoder.r_init = self.ddc_r
self.coarse_decoder.set_r(self.ddc_r)
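
The coarse decoder used for Double-Decoder Consistency is a deep copy of the main decoder run at a larger reduction factor; a minimal sketch of that copy-and-retune pattern with a hypothetical decoder module:

import copy

import torch.nn as nn


class TinyDecoder(nn.Module):
    """Hypothetical decoder exposing a reduction factor `r`."""

    def __init__(self, r: int = 2):
        super().__init__()
        self.r = r
        self.proj = nn.Linear(80, 80 * r)

    def set_r(self, r: int):
        self.r = r
        self.proj = nn.Linear(80, 80 * r)


decoder = TinyDecoder(r=2)
coarse_decoder = copy.deepcopy(decoder)  # independent copy of the weights
coarse_decoder.set_r(6)                  # run it at a coarser reduction factor, e.g. ddc_r
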
@@ -80,6 +79,13 @@ class BaseTacotron(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
"""Load model checkpoint and set up internals.
Args:
config (Coqpit): model configuration.
checkpoint_path (str): path to checkpoint file.
eval (bool): whether to load model for evaluation.
"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
# TODO: set r in run-time by taking it from the new config
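
For context, a usage sketch of the loading path above; the checkpoint path and the `model`/`config` objects are placeholders:

import torch

from TTS.utils.io import load_fsspec

# inspect a checkpoint the same way load_checkpoint() does internally
state = load_fsspec("run/checkpoint_50000.pth", map_location=torch.device("cpu"))
print(state.keys())  # expect at least a "model" key holding the state dict
# model.load_state_dict(state["model"])

# or go through the model API shown above (eval=True switches to inference mode)
# model.load_checkpoint(config, "run/checkpoint_50000.pth", eval=True)
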
@@ -98,45 +104,9 @@ class BaseTacotron(BaseTTS):
assert not self.training
def get_criterion(self) -> nn.Module:
"""Get the model criterion used in training."""
return TacotronLoss(self.config)
@staticmethod
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
return get_speaker_manager(config, restore_path, data, out_path)
def get_aux_input(self, **kwargs) -> Dict:
"""Compute Tacotron's auxiliary inputs based on model config.
- speaker d_vector
- style wav for GST
- speaker ID for speaker embedding
"""
# setup speaker_id
if self.config.use_speaker_embedding:
speaker_id = kwargs.get("speaker_id", 0)
else:
speaker_id = None
# setup d_vector
d_vector = (
self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
if self.config.use_d_vector_file and self.config.use_speaker_embedding
else None
)
# setup style_mel
if "style_wav" in kwargs:
style_wav = kwargs["style_wav"]
elif self.config.has("gst_style_input"):
style_wav = self.config.gst_style_input
else:
style_wav = None
if style_wav is None and "use_gst" in self.config and self.config.use_gst:
# initialize GST with a zero dict.
style_wav = {}
print("WARNING: No GST style wav provided, so a zero tensor is used instead!")
for i in range(self.config.gst["gst_num_style_tokens"]):
style_wav[str(i)] = 0
aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
return aux_inputs
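
The removed `get_aux_input` fell back to a zero GST style dict when no style wav was given; a stand-alone sketch of that fallback, with a hypothetical token count:

def zero_gst_style(num_style_tokens: int) -> dict:
    """Build a style 'wav' dict that gives every GST token a zero weight."""
    return {str(i): 0 for i in range(num_style_tokens)}


style_wav = zero_gst_style(10)  # {"0": 0, "1": 0, ..., "9": 0}
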
#############################
# COMMON COMPUTE FUNCTIONS
#############################
@@ -182,15 +152,6 @@ class BaseTacotron(BaseTTS):
# EMBEDDING FUNCTIONS
#############################
def compute_speaker_embedding(self, speaker_ids):
"""Compute speaker embedding vectors"""
if hasattr(self, "speaker_embedding") and speaker_ids is None:
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
self.embedded_speakers = self.speaker_embedding(speaker_ids).unsqueeze(1)
if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
self.embedded_speakers_projected = self.speaker_project_mel(self.embedded_speakers).squeeze(1)
def compute_gst(self, inputs, style_input, speaker_embedding=None):
"""Compute global style token"""
if isinstance(style_input, dict):
@@ -242,4 +203,4 @@ class BaseTacotron(BaseTTS):
self.decoder.set_r(r)
if trainer.config.bidirectional_decoder:
trainer.model.decoder_backward.set_r(r)
print(f"\n > Number of output frames: {self.decoder.r}")
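
For reference, a stand-alone sketch of the gradual-training idea this last hunk serves: as training progresses, the reduction factor r is stepped down according to a schedule. The entry layout is assumed to be [start_step, r, batch_size] and the values are illustrative, not the project defaults:

# each entry assumed to be [start_step, r, batch_size] -- illustrative values only
schedule = [[0, 7, 64], [10_000, 5, 64], [50_000, 3, 32], [130_000, 2, 32]]


def lookup_r(global_step, schedule):
    """Return the (r, batch_size) of the latest schedule entry already reached."""
    r, batch_size = schedule[0][1], schedule[0][2]
    for start_step, new_r, new_bs in schedule:
        if global_step >= start_step:
            r, batch_size = new_r, new_bs
    return r, batch_size


r, _ = lookup_r(60_000, schedule)  # -> r == 3
# self.decoder.set_r(r)            # as in the callback above
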