mirror of https://github.com/coqui-ai/TTS.git
Comment BaseTacotron and remove unused funcs
This commit is contained in:
parent
aa25f70b95
commit
330ee7d208
|
@ -9,15 +9,15 @@ from torch import nn
|
|||
from TTS.tts.layers.losses import TacotronLoss
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.utils.generic_utils import format_aux_input
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.training import gradual_training_scheduler
|
||||
|
||||
|
||||
class BaseTacotron(BaseTTS):
|
||||
"""Base class shared by Tacotron and Tacotron2"""
|
||||
|
||||
def __init__(self, config: Coqpit):
|
||||
"""Abstract Tacotron class"""
|
||||
super().__init__(config)
|
||||
|
||||
# pass all config fields as class attributes
|
||||
|
@ -45,6 +45,7 @@ class BaseTacotron(BaseTTS):
|
|||
|
||||
@staticmethod
|
||||
def _format_aux_input(aux_input: Dict) -> Dict:
|
||||
"""Set missing fields to their default values"""
|
||||
if aux_input:
|
||||
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
|
||||
return None
|
||||
|
@ -53,14 +54,12 @@ class BaseTacotron(BaseTTS):
|
|||
# INIT FUNCTIONS
|
||||
#############################
|
||||
|
||||
def _init_states(self):
|
||||
self.embedded_speakers = None
|
||||
self.embedded_speakers_projected = None
|
||||
|
||||
def _init_backward_decoder(self):
|
||||
"""Init the backward decoder for Forward-Backward decoding."""
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
|
||||
def _init_coarse_decoder(self):
|
||||
"""Init the coarse decoder for Double-Decoder Consistency."""
|
||||
self.coarse_decoder = copy.deepcopy(self.decoder)
|
||||
self.coarse_decoder.r_init = self.ddc_r
|
||||
self.coarse_decoder.set_r(self.ddc_r)
|
||||
|
@ -80,6 +79,13 @@ class BaseTacotron(BaseTTS):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
"""Load model checkpoint and set up internals.
|
||||
|
||||
Args:
|
||||
config (Coqpi): model configuration.
|
||||
checkpoint_path (str): path to checkpoint file.
|
||||
eval (bool): whether to load model for evaluation.
|
||||
"""
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
# TODO: set r in run-time by taking it from the new config
|
||||
|
@ -98,45 +104,9 @@ class BaseTacotron(BaseTTS):
|
|||
assert not self.training
|
||||
|
||||
def get_criterion(self) -> nn.Module:
|
||||
"""Get the model criterion used in training."""
|
||||
return TacotronLoss(self.config)
|
||||
|
||||
@staticmethod
|
||||
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
|
||||
return get_speaker_manager(config, restore_path, data, out_path)
|
||||
|
||||
def get_aux_input(self, **kwargs) -> Dict:
|
||||
"""Compute Tacotron's auxiliary inputs based on model config.
|
||||
- speaker d_vector
|
||||
- style wav for GST
|
||||
- speaker ID for speaker embedding
|
||||
"""
|
||||
# setup speaker_id
|
||||
if self.config.use_speaker_embedding:
|
||||
speaker_id = kwargs.get("speaker_id", 0)
|
||||
else:
|
||||
speaker_id = None
|
||||
# setup d_vector
|
||||
d_vector = (
|
||||
self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
|
||||
if self.config.use_d_vector_file and self.config.use_speaker_embedding
|
||||
else None
|
||||
)
|
||||
# setup style_mel
|
||||
if "style_wav" in kwargs:
|
||||
style_wav = kwargs["style_wav"]
|
||||
elif self.config.has("gst_style_input"):
|
||||
style_wav = self.config.gst_style_input
|
||||
else:
|
||||
style_wav = None
|
||||
if style_wav is None and "use_gst" in self.config and self.config.use_gst:
|
||||
# inicialize GST with zero dict.
|
||||
style_wav = {}
|
||||
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
|
||||
for i in range(self.config.gst["gst_num_style_tokens"]):
|
||||
style_wav[str(i)] = 0
|
||||
aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
|
||||
return aux_inputs
|
||||
|
||||
#############################
|
||||
# COMMON COMPUTE FUNCTIONS
|
||||
#############################
|
||||
|
@ -182,15 +152,6 @@ class BaseTacotron(BaseTTS):
|
|||
# EMBEDDING FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_speaker_embedding(self, speaker_ids):
|
||||
"""Compute speaker embedding vectors"""
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
self.embedded_speakers = self.speaker_embedding(speaker_ids).unsqueeze(1)
|
||||
if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
|
||||
self.embedded_speakers_projected = self.speaker_project_mel(self.embedded_speakers).squeeze(1)
|
||||
|
||||
def compute_gst(self, inputs, style_input, speaker_embedding=None):
|
||||
"""Compute global style token"""
|
||||
if isinstance(style_input, dict):
|
||||
|
@ -242,4 +203,4 @@ class BaseTacotron(BaseTTS):
|
|||
self.decoder.set_r(r)
|
||||
if trainer.config.bidirectional_decoder:
|
||||
trainer.model.decoder_backward.set_r(r)
|
||||
print(f"\n > Number of output frames: {self.decoder.r}")
|
||||
print(f"\n > Number of output frames: {self.decoder.r}")
|
||||
|
|
Loading…
Reference in New Issue