mirror of https://github.com/coqui-ai/TTS.git
Comment BaseTacotron and remove unused funcs
This commit is contained in:
parent aa25f70b95
commit 330ee7d208
@@ -9,15 +9,15 @@ from torch import nn
 from TTS.tts.layers.losses import TacotronLoss
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import sequence_mask
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
 from TTS.utils.generic_utils import format_aux_input
 from TTS.utils.io import load_fsspec
 from TTS.utils.training import gradual_training_scheduler
 
 
 class BaseTacotron(BaseTTS):
+    """Base class shared by Tacotron and Tacotron2"""
+
     def __init__(self, config: Coqpit):
-        """Abstract Tacotron class"""
         super().__init__(config)
 
         # pass all config fields as class attributes
@@ -45,6 +45,7 @@ class BaseTacotron(BaseTTS):
 
     @staticmethod
     def _format_aux_input(aux_input: Dict) -> Dict:
+        """Set missing fields to their default values"""
        if aux_input:
             return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
         return None
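The new docstring states what the helper already did: `_format_aux_input` merges the caller's partial dict over a set of defaults via `format_aux_input`. A minimal sketch of that behaviour, using a simplified stand-in for `TTS.utils.generic_utils.format_aux_input` (the stand-in body is an assumption, not the library implementation):

```python
from typing import Dict, Optional


def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
    """Simplified stand-in: copy `kwargs` and fill in any missing default key."""
    out = dict(kwargs)
    for key, value in def_args.items():
        out.setdefault(key, value)
    return out


def _format_aux_input(aux_input: Dict) -> Optional[Dict]:
    """Set missing fields to their default values (mirrors the method above)."""
    if aux_input:
        return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
    return None


print(_format_aux_input({"speaker_ids": [3]}))  # {'speaker_ids': [3], 'd_vectors': None}
print(_format_aux_input(None))                  # None: falsy input passes through untouched
```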
@@ -53,14 +54,12 @@ class BaseTacotron(BaseTTS):
     # INIT FUNCTIONS
     #############################
 
-    def _init_states(self):
-        self.embedded_speakers = None
-        self.embedded_speakers_projected = None
-
     def _init_backward_decoder(self):
+        """Init the backward decoder for Forward-Backward decoding."""
         self.decoder_backward = copy.deepcopy(self.decoder)
 
     def _init_coarse_decoder(self):
+        """Init the coarse decoder for Double-Decoder Consistency."""
         self.coarse_decoder = copy.deepcopy(self.decoder)
         self.coarse_decoder.r_init = self.ddc_r
         self.coarse_decoder.set_r(self.ddc_r)
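The added docstrings name the techniques the two helpers serve: Forward-Backward decoding and Double-Decoder Consistency (DDC). The coarse decoder is a deep copy of the main decoder driven at a larger reduction factor. A sketch of that pattern with a toy module standing in for the Tacotron decoder; `ToyDecoder` and `ddc_r = 6` are illustrative, not part of the commit:

```python
import copy

from torch import nn


class ToyDecoder(nn.Module):
    """Stand-in for the Tacotron decoder; only the reduction factor matters here."""

    def __init__(self, r_init: int = 2):
        super().__init__()
        self.r_init = r_init
        self.r = r_init
        self.rnn = nn.GRUCell(80, 256)

    def set_r(self, new_r: int):
        # predict `new_r` frames per decoder step; a larger r means coarser decoding
        self.r = new_r


decoder = ToyDecoder(r_init=2)
ddc_r = 6  # illustrative coarse reduction factor, normally read from the model config

# Same pattern as _init_coarse_decoder: an independent copy of the decoder that
# runs at the coarse reduction factor for Double-Decoder Consistency.
coarse_decoder = copy.deepcopy(decoder)
coarse_decoder.r_init = ddc_r
coarse_decoder.set_r(ddc_r)
print(decoder.r, coarse_decoder.r)  # 2 6
```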
@@ -80,6 +79,13 @@ class BaseTacotron(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ): # pylint: disable=unused-argument, redefined-builtin
+        """Load model checkpoint and set up internals.
+
+        Args:
+            config (Coqpi): model configuration.
+            checkpoint_path (str): path to checkpoint file.
+            eval (bool): whether to load model for evaluation.
+        """
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         # TODO: set r in run-time by taking it from the new config
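The docstring documents the existing behaviour: a checkpoint is a dict whose `"model"` entry holds the state dict; it is mapped to CPU, loaded into the module, and `eval=True` switches the model to inference mode. A minimal sketch of that round trip with a toy module in place of the model (file path and module are placeholders; the real method loads through `load_fsspec`, so remote paths work too):

```python
import torch
from torch import nn

# Save a checkpoint in the layout load_checkpoint expects: weights under "model".
net = nn.Linear(4, 2)
torch.save({"model": net.state_dict()}, "/tmp/toy_checkpoint.pth")

# Restore: map to CPU first, load the state dict, then switch to eval mode,
# which is what eval=True does for the model.
restored = nn.Linear(4, 2)
state = torch.load("/tmp/toy_checkpoint.pth", map_location=torch.device("cpu"))
restored.load_state_dict(state["model"])
restored.eval()
assert not restored.training
```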
@@ -98,45 +104,9 @@ class BaseTacotron(BaseTTS):
             assert not self.training
 
     def get_criterion(self) -> nn.Module:
+        """Get the model criterion used in training."""
         return TacotronLoss(self.config)
 
-    @staticmethod
-    def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
-        return get_speaker_manager(config, restore_path, data, out_path)
-
-    def get_aux_input(self, **kwargs) -> Dict:
-        """Compute Tacotron's auxiliary inputs based on model config.
-        - speaker d_vector
-        - style wav for GST
-        - speaker ID for speaker embedding
-        """
-        # setup speaker_id
-        if self.config.use_speaker_embedding:
-            speaker_id = kwargs.get("speaker_id", 0)
-        else:
-            speaker_id = None
-        # setup d_vector
-        d_vector = (
-            self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
-            if self.config.use_d_vector_file and self.config.use_speaker_embedding
-            else None
-        )
-        # setup style_mel
-        if "style_wav" in kwargs:
-            style_wav = kwargs["style_wav"]
-        elif self.config.has("gst_style_input"):
-            style_wav = self.config.gst_style_input
-        else:
-            style_wav = None
-        if style_wav is None and "use_gst" in self.config and self.config.use_gst:
-            # inicialize GST with zero dict.
-            style_wav = {}
-            print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
-            for i in range(self.config.gst["gst_num_style_tokens"]):
-                style_wav[str(i)] = 0
-        aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
-        return aux_inputs
-
     #############################
     # COMMON COMPUTE FUNCTIONS
     #############################
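One detail of the removed `get_aux_input` worth keeping in mind: with GST enabled and no style wav given, it fell back to a dict with one zero weight per style token, which matches the `dict` case `compute_gst` checks for in the next hunk. A small sketch of that structure; the token count here is illustrative and normally comes from `config.gst["gst_num_style_tokens"]`:

```python
# Zero GST style input as built by the removed fallback: one zero weight per token.
gst_num_style_tokens = 10  # illustrative; normally config.gst["gst_num_style_tokens"]
style_wav = {str(i): 0 for i in range(gst_num_style_tokens)}

print(list(style_wav.items())[:3])  # [('0', 0), ('1', 0), ('2', 0)]
print(isinstance(style_wav, dict))  # True -> compute_gst takes the dict branch
```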
@@ -182,15 +152,6 @@ class BaseTacotron(BaseTTS):
     # EMBEDDING FUNCTIONS
     #############################
 
-    def compute_speaker_embedding(self, speaker_ids):
-        """Compute speaker embedding vectors"""
-        if hasattr(self, "speaker_embedding") and speaker_ids is None:
-            raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
-        if hasattr(self, "speaker_embedding") and speaker_ids is not None:
-            self.embedded_speakers = self.speaker_embedding(speaker_ids).unsqueeze(1)
-        if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
-            self.embedded_speakers_projected = self.speaker_project_mel(self.embedded_speakers).squeeze(1)
-
     def compute_gst(self, inputs, style_input, speaker_embedding=None):
         """Compute global style token"""
         if isinstance(style_input, dict):