diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index c661c4cc..d0cc81cc 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -9,15 +9,15 @@ from torch import nn from TTS.tts.layers.losses import TacotronLoss from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import sequence_mask -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager from TTS.utils.generic_utils import format_aux_input from TTS.utils.io import load_fsspec from TTS.utils.training import gradual_training_scheduler class BaseTacotron(BaseTTS): + """Base class shared by Tacotron and Tacotron2""" + def __init__(self, config: Coqpit): - """Abstract Tacotron class""" super().__init__(config) # pass all config fields as class attributes @@ -45,6 +45,7 @@ class BaseTacotron(BaseTTS): @staticmethod def _format_aux_input(aux_input: Dict) -> Dict: + """Set missing fields to their default values""" if aux_input: return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) return None @@ -53,14 +54,12 @@ class BaseTacotron(BaseTTS): # INIT FUNCTIONS ############################# - def _init_states(self): - self.embedded_speakers = None - self.embedded_speakers_projected = None - def _init_backward_decoder(self): + """Init the backward decoder for Forward-Backward decoding.""" self.decoder_backward = copy.deepcopy(self.decoder) def _init_coarse_decoder(self): + """Init the coarse decoder for Double-Decoder Consistency.""" self.coarse_decoder = copy.deepcopy(self.decoder) self.coarse_decoder.r_init = self.ddc_r self.coarse_decoder.set_r(self.ddc_r) @@ -80,6 +79,13 @@ class BaseTacotron(BaseTTS): def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin + """Load model checkpoint and set up internals. + + Args: + config (Coqpi): model configuration. + checkpoint_path (str): path to checkpoint file. + eval (bool): whether to load model for evaluation. + """ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) # TODO: set r in run-time by taking it from the new config @@ -98,45 +104,9 @@ class BaseTacotron(BaseTTS): assert not self.training def get_criterion(self) -> nn.Module: + """Get the model criterion used in training.""" return TacotronLoss(self.config) - @staticmethod - def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager: - return get_speaker_manager(config, restore_path, data, out_path) - - def get_aux_input(self, **kwargs) -> Dict: - """Compute Tacotron's auxiliary inputs based on model config. - - speaker d_vector - - style wav for GST - - speaker ID for speaker embedding - """ - # setup speaker_id - if self.config.use_speaker_embedding: - speaker_id = kwargs.get("speaker_id", 0) - else: - speaker_id = None - # setup d_vector - d_vector = ( - self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0]) - if self.config.use_d_vector_file and self.config.use_speaker_embedding - else None - ) - # setup style_mel - if "style_wav" in kwargs: - style_wav = kwargs["style_wav"] - elif self.config.has("gst_style_input"): - style_wav = self.config.gst_style_input - else: - style_wav = None - if style_wav is None and "use_gst" in self.config and self.config.use_gst: - # inicialize GST with zero dict. - style_wav = {} - print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!") - for i in range(self.config.gst["gst_num_style_tokens"]): - style_wav[str(i)] = 0 - aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector} - return aux_inputs - ############################# # COMMON COMPUTE FUNCTIONS ############################# @@ -182,15 +152,6 @@ class BaseTacotron(BaseTTS): # EMBEDDING FUNCTIONS ############################# - def compute_speaker_embedding(self, speaker_ids): - """Compute speaker embedding vectors""" - if hasattr(self, "speaker_embedding") and speaker_ids is None: - raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided") - if hasattr(self, "speaker_embedding") and speaker_ids is not None: - self.embedded_speakers = self.speaker_embedding(speaker_ids).unsqueeze(1) - if hasattr(self, "speaker_project_mel") and speaker_ids is not None: - self.embedded_speakers_projected = self.speaker_project_mel(self.embedded_speakers).squeeze(1) - def compute_gst(self, inputs, style_input, speaker_embedding=None): """Compute global style token""" if isinstance(style_input, dict): @@ -242,4 +203,4 @@ class BaseTacotron(BaseTTS): self.decoder.set_r(r) if trainer.config.bidirectional_decoder: trainer.model.decoder_backward.set_r(r) - print(f"\n > Number of output frames: {self.decoder.r}") \ No newline at end of file + print(f"\n > Number of output frames: {self.decoder.r}")