From b3fb0e19e8fbc57636ffa014d22d62bbaa709c61 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 22 Apr 2022 12:39:46 +0200
Subject: [PATCH] Implement get_state_dict

---
 TTS/tts/models/forward_tts.py     | 50 ++++++++++++++++-----
 TTS/tts/models/forward_tts_e2e.py | 73 ++++++++++++++++++++-----------
 2 files changed, 88 insertions(+), 35 deletions(-)

diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py
index f7db282c..147093c5 100644
--- a/TTS/tts/models/forward_tts.py
+++ b/TTS/tts/models/forward_tts.py
@@ -95,6 +95,9 @@ class ForwardTTSArgs(Coqpit):
         num_speakers (int):
             Number of speakers for the speaker embedding layer. Defaults to 0.
 
+        use_speaker_embedding (bool):
+            Whether to use a speaker embedding layer. Defaults to False.
+
         speakers_file (str):
             Path to the speaker mapping file for the Speaker Manager. Defaults to None.
 
@@ -107,8 +110,10 @@ class ForwardTTSArgs(Coqpit):
 
         d_vector_dim (int):
             Number of d-vector channels. Defaults to 0.
-    """
 
+        d_vector_file (str):
+            Path to the d-vector file. Defaults to None.
+    """
     num_chars: int = None
     out_channels: int = 80
     hidden_channels: int = 384
@@ -148,6 +153,7 @@ class ForwardTTSArgs(Coqpit):
     max_duration: int = 75
     num_speakers: int = 1
     use_speaker_embedding: bool = False
+    speaker_embedding_channels: int = 256
     speakers_file: str = None
     use_d_vector_file: bool = False
     d_vector_dim: int = None
@@ -177,9 +183,18 @@ class ForwardTTS(BaseTTS):
             Defaults to None.
 
     Examples:
-        >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs
-        >>> config = ForwardTTSArgs()
-        >>> model = ForwardTTS(config)
+        Instantiate the model directly.
+
+        >>> from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
+        >>> args = ForwardTTSArgs(num_chars=10)
+        >>> model = ForwardTTS(args)
+
+        Instantiate the model from config.
+
+        >>> from TTS.tts.configs.fast_pitch_config import FastPitchConfig
+        >>> from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
+        >>> config = FastPitchConfig(model_args=ForwardTTSArgs(num_chars=10))
+        >>> model = ForwardTTS.init_from_config(config)
     """
 
     # pylint: disable=dangerous-default-value
@@ -272,6 +287,7 @@ class ForwardTTS(BaseTTS):
             self._init_d_vector()
 
     def _init_speaker_embedding(self):
+        """Init class arguments for training with a speaker embedding layer."""
         # pylint: disable=attribute-defined-outside-init
         if self.num_speakers > 0:
             print(" > initialization of speaker-embedding layers.")
@@ -279,6 +295,7 @@ class ForwardTTS(BaseTTS):
             self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
 
     def _init_d_vector(self):
+        """Init class arguments for training with external speaker embeddings."""
         # pylint: disable=attribute-defined-outside-init
         if hasattr(self, "emb_g"):
             raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
@@ -286,7 +303,7 @@ class ForwardTTS(BaseTTS):
 
     @staticmethod
     def _set_cond_input(aux_input: Dict):
-        """Set the speaker conditioning input based on the multi-speaker mode."""
+        """Set auxiliary model inputs based on the model configuration."""
         sid, g, lid = None, None, None
         if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
             sid = aux_input["speaker_ids"]
@@ -305,12 +322,19 @@ class ForwardTTS(BaseTTS):
         return sid, g, lid
 
     def get_aux_input(self, aux_input: Dict):
+        """Get auxiliary model inputs based on the model configuration."""
        sid, g, lid = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
 
     @staticmethod
     def generate_attn(dr, x_mask, y_mask=None):
-        """Generate an attention mask from the durations.
+        """Generate an attention mask from the linear scale durations.
+
+        Args:
+            dr (Tensor): Linear scale durations.
+            x_mask (Tensor): Mask for the input (character) sequence.
+            y_mask (Tensor): Mask for the output (spectrogram) sequence. If None, it is computed
+                from the durations. Defaults to None.
 
         Shapes
            - dr: :math:`(B, T_{en})`
@@ -327,8 +351,14 @@ class ForwardTTS(BaseTTS):
         return attn
 
     def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
-        """Generate attention alignment map from durations and
-        expand encoder outputs
+        """Generate attention alignment map from linear scale durations and
+        expand encoder outputs.
+
+        Args:
+            en (Tensor): Encoder outputs.
+            dr (Tensor): Linear scale durations.
+            x_mask (Tensor): Mask for the input (character) sequence.
+            y_mask (Tensor): Mask for the output (spectrogram) sequence.
 
         Shapes:
             - en: :math:`(B, D_{en}, T_{en})`
@@ -360,8 +390,8 @@ class ForwardTTS(BaseTTS):
         5. Round the duration values.
 
         Args:
-            o_dr_log: Log scale durations.
-            x_mask: Input text mask.
+            o_dr_log (Tensor): Log scale durations.
+            x_mask (Tensor): Input text mask.
 
         Shapes:
             - o_dr_log: :math:`(B, T_{de})`
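
The `generate_attn` docstring documented above describes turning per-token linear scale durations into a hard attention (alignment) map. For reviewers, here is a minimal, self-contained sketch of that expansion; `sequence_mask` and `durations_to_attn` are illustrative helper names, not the exact internals of `ForwardTTS.generate_attn`:

    import torch

    def sequence_mask(lengths, max_len=None):
        # (B,) lengths -> (B, max_len) boolean mask, True where t < length.
        max_len = int(max_len or lengths.max())
        steps = torch.arange(max_len, device=lengths.device)
        return steps.unsqueeze(0) < lengths.unsqueeze(1)

    def durations_to_attn(dr):
        # dr: (B, T_en) integer durations -> (B, T_en, T_de) 0/1 alignment map.
        b, t_en = dr.shape
        t_de = int(dr.sum(1).max())
        ends = torch.cumsum(dr, dim=1).view(b * t_en)  # exclusive end frame per token
        path = sequence_mask(ends, t_de).view(b, t_en, t_de).float()
        # Subtract the previous token's coverage so each frame maps to exactly one token.
        path = path - torch.nn.functional.pad(path, (0, 0, 1, 0))[:, :-1]
        return path

    dr = torch.tensor([[2, 1, 3]])  # token 0 -> 2 frames, token 1 -> 1 frame, token 2 -> 3 frames
    print(durations_to_attn(dr)[0])
    # tensor([[1., 1., 0., 0., 0., 0.],
    #         [0., 0., 1., 0., 0., 0.],
    #         [0., 0., 0., 1., 1., 1.]])

In the model itself the result is additionally masked with `x_mask`/`y_mask` to handle padded batches; the sketch skips that for brevity.
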
diff --git a/TTS/tts/models/forward_tts_e2e.py b/TTS/tts/models/forward_tts_e2e.py
index a330cb46..18b6d7c9 100644
--- a/TTS/tts/models/forward_tts_e2e.py
+++ b/TTS/tts/models/forward_tts_e2e.py
@@ -474,31 +474,6 @@ class ForwardTTSE2e(BaseTTSE2E):
             model_outputs = {**encoder_outputs}
         return model_outputs
 
-    @staticmethod
-    def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False):
-        """Initiate model from config
-
-        Args:
-            config (ForwardTTSE2eConfig): Model config.
-            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
-                Defaults to None.
-        """
-        from TTS.utils.audio.processor import AudioProcessor
-
-        tokenizer, new_config = TTSTokenizer.init_from_config(config)
-        speaker_manager = SpeakerManager.init_from_config(config, samples)
-        # language_manager = LanguageManager.init_from_config(config)
-        return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager)
-
-    def load_checkpoint(
-        self, config, checkpoint_path, eval=False
-    ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
-        self.load_state_dict(state["model"])
-        if eval:
-            self.eval()
-            assert not self.training
-
     def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
         if optimizer_idx == 0:
             tokens = batch["text_input"]
- """ - from TTS.utils.audio.processor import AudioProcessor - - tokenizer, new_config = TTSTokenizer.init_from_config(config) - speaker_manager = SpeakerManager.init_from_config(config, samples) - # language_manager = LanguageManager.init_from_config(config) - return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager) - - def load_checkpoint( - self, config, checkpoint_path, eval=False - ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) - self.load_state_dict(state["model"]) - if eval: - self.eval() - assert not self.training - def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): if optimizer_idx == 0: tokens = batch["text_input"] @@ -1000,3 +975,51 @@ class ForwardTTSE2e(BaseTTSE2E): mel_fmax=self.config.audio.mel_fmax, mel_fmin=self.config.audio.mel_fmin, ) + + @staticmethod + def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False): + """Initiate model from config + + Args: + config (ForwardTTSE2eConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio.processor import AudioProcessor + + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + # language_manager = LanguageManager.init_from_config(config) + return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): + """Load model from a checkpoint created by the 👟""" + # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_state_dict(self): + """Custom state dict of the model with all the necessary components for inference.""" + save_state = { + "config": self.config.to_dict(), + "args": self.args.to_dict(), + "model": self.state_dict + } + + if hasattr(self, "emb_g"): + save_state["speaker_ids"] = self.speaker_manager.speaker_ids + + if self.args.use_d_vector_file: + # TODO: implement saving of d_vectors + ... + return save_state + + def save(self, config, checkpoint_path): + """Save model to a file.""" + save_state = self.get_state_dict(config, checkpoint_path) + torch.save(save_state, checkpoint_path)