mirror of https://github.com/coqui-ai/TTS.git
Implement get_state_dict
This commit is contained in:
parent ce4f96292a
commit b3fb0e19e8
@@ -95,6 +95,9 @@ class ForwardTTSArgs(Coqpit):
         num_speakers (int):
             Number of speakers for the speaker embedding layer. Defaults to 0.
 
+        use_speaker_embedding (bool):
+            Whether to use a speaker embedding layer. Defaults to False.
+
         speakers_file (str):
             Path to the speaker mapping file for the Speaker Manager. Defaults to None.
 
@@ -107,8 +110,10 @@ class ForwardTTSArgs(Coqpit):
         d_vector_dim (int):
             Number of d-vector channels. Defaults to 0.
 
+        d_vector_file (str):
+            Path to the d-vector file. Defaults to None.
     """
 
     num_chars: int = None
     out_channels: int = 80
     hidden_channels: int = 384
@@ -148,6 +153,7 @@ class ForwardTTSArgs(Coqpit):
     max_duration: int = 75
     num_speakers: int = 1
     use_speaker_embedding: bool = False
+    speaker_embedding_channels: int = 256
     speakers_file: str = None
     use_d_vector_file: bool = False
     d_vector_dim: int = None
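
Note: the new speaker_embedding_channels field slots in beside the existing multi-speaker options above. A hypothetical configuration exercising the fields this commit touches (the import path is assumed, not taken from the diff):

    from TTS.tts.models.forward_tts import ForwardTTSArgs

    # Learned-speaker-embedding setup: one 256-dim vector per speaker.
    args = ForwardTTSArgs(
        num_chars=120,
        num_speakers=4,
        use_speaker_embedding=True,
        speaker_embedding_channels=256,  # field added in this commit
    )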
@@ -177,9 +183,18 @@ class ForwardTTS(BaseTTS):
             Defaults to None.
 
     Examples:
-        >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs
-        >>> config = ForwardTTSArgs()
-        >>> model = ForwardTTS(config)
+        Instantiate the model directly.
+
+        >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eArgs
+        >>> args = ForwardTTSE2eArgs()
+        >>> model = ForwardTTSE2e(args)
+
+        Instantiate the model from config.
+
+        >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e
+        >>> from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig
+        >>> config = FastPitchE2eConfig(num_chars=10)
+        >>> model = ForwardTTSE2e.init_from_config(config)
     """
 
     # pylint: disable=dangerous-default-value
@@ -272,6 +287,7 @@ class ForwardTTS(BaseTTS):
             self._init_d_vector()
 
     def _init_speaker_embedding(self):
+        """Init class arguments for training with a speaker embedding layer."""
         # pylint: disable=attribute-defined-outside-init
         if self.num_speakers > 0:
             print(" > initialization of speaker-embedding layers.")
@@ -279,6 +295,7 @@ class ForwardTTS(BaseTTS):
         self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
 
     def _init_d_vector(self):
+        """Init class arguments for training with external speaker embeddings."""
         # pylint: disable=attribute-defined-outside-init
         if hasattr(self, "emb_g"):
             raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
@@ -286,7 +303,7 @@ class ForwardTTS(BaseTTS):
 
     @staticmethod
     def _set_cond_input(aux_input: Dict):
-        """Set the speaker conditioning input based on the multi-speaker mode."""
+        """Set auxiliary model inputs based on the model configuration."""
         sid, g, lid = None, None, None
         if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
             sid = aux_input["speaker_ids"]
@@ -305,12 +322,19 @@ class ForwardTTS(BaseTTS):
         return sid, g, lid
 
     def get_aux_input(self, aux_input: Dict):
+        """Get auxiliary model inputs based on the model configuration."""
         sid, g, lid = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
 
     @staticmethod
     def generate_attn(dr, x_mask, y_mask=None):
-        """Generate an attention mask from the durations.
+        """Generate an attention mask from the linear scale durations.
+
+        Args:
+            dr (Tensor): Linear scale durations.
+            x_mask (Tensor): Mask for the input (character) sequence.
+            y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations
+                if None. Defaults to None.
 
         Shapes
            - dr: :math:`(B, T_{en})`
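
Note: the alignment map generate_attn documents can be sketched without the library's mask helpers. A minimal illustrative version, assuming only torch (durations_to_attn is a hypothetical name, not part of the diff):

    import torch

    def durations_to_attn(dr: torch.Tensor) -> torch.Tensor:
        """Expand linear durations (B, T_en) into a hard alignment map (B, T_en, T_de)."""
        ends = torch.cumsum(dr, dim=1)  # (B, T_en) end frame of each token
        starts = ends - dr              # (B, T_en) start frame of each token
        frames = torch.arange(int(ends[:, -1].max()), device=dr.device)
        # attn[b, t, f] = 1 iff decoder frame f falls inside token t's span
        attn = (frames >= starts[..., None]) & (frames < ends[..., None])
        return attn.to(dr.dtype)

    # Durations [2, 3] cover 5 frames: frames 0-1 -> token 0, frames 2-4 -> token 1.
    attn = durations_to_attn(torch.tensor([[2.0, 3.0]]))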
@@ -327,8 +351,14 @@ class ForwardTTS(BaseTTS):
         return attn
 
     def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
-        """Generate attention alignment map from durations and
-        expand encoder outputs
+        """Generate attention alignment map from linear scale durations and
+        expand encoder outputs.
+
+        Args:
+            en (Tensor): Encoder outputs.
+            dr (Tensor): Linear scale durations.
+            x_mask (Tensor): Mask for the input (character) sequence.
+            y_mask (Tensor): Mask for the output (spectrogram) sequence.
 
         Shapes:
             - en: :math:`(B, D_{en}, T_{en})`
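
Given such a map, the expansion step expand_encoder_outputs describes reduces to a batched matrix product (sketch, reusing the hypothetical durations_to_attn above):

    import torch

    en = torch.randn(1, 4, 2)                             # (B, D_en, T_en) encoder outputs
    attn = durations_to_attn(torch.tensor([[2.0, 3.0]]))  # (B, T_en, T_de)
    # Each decoder frame receives the encoder vector of the token it aligns to.
    o_en_ex = torch.bmm(en, attn)                         # (B, D_en, T_de)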
@@ -360,8 +390,8 @@ class ForwardTTS(BaseTTS):
         5. Round the duration values.
 
         Args:
-            o_dr_log: Log scale durations.
-            x_mask: Input text mask.
+            o_dr_log (Tensor): Log scale durations.
+            x_mask (Tensor): Input text mask.
 
         Shapes:
             - o_dr_log: :math:`(B, T_{de})`
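
The steps this docstring enumerates amount to a few tensor ops. A standalone sketch of the log-to-linear conversion (hypothetical free function, not the model's method; length_scale stands in for the model's pace attribute):

    import torch

    def format_durations(o_dr_log: torch.Tensor, x_mask: torch.Tensor, length_scale: float = 1.0) -> torch.Tensor:
        o_dr = (torch.exp(o_dr_log) - 1) * x_mask * length_scale  # undo the log, zero padded positions, apply pace
        o_dr[o_dr < 1] = 1.0                                      # clamp: every position gets at least one frame
        return torch.round(o_dr)                                  # integer frame counts for the expansion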
@@ -474,31 +474,6 @@ class ForwardTTSE2e(BaseTTSE2E):
         model_outputs = {**encoder_outputs}
         return model_outputs
 
-    @staticmethod
-    def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False):
-        """Initiate model from config
-
-        Args:
-            config (ForwardTTSE2eConfig): Model config.
-            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
-                Defaults to None.
-        """
-        from TTS.utils.audio.processor import AudioProcessor
-
-        tokenizer, new_config = TTSTokenizer.init_from_config(config)
-        speaker_manager = SpeakerManager.init_from_config(config, samples)
-        # language_manager = LanguageManager.init_from_config(config)
-        return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager)
-
-    def load_checkpoint(
-        self, config, checkpoint_path, eval=False
-    ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
-        self.load_state_dict(state["model"])
-        if eval:
-            self.eval()
-            assert not self.training
-
     def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
         if optimizer_idx == 0:
             tokens = batch["text_input"]
@@ -1000,3 +975,51 @@ class ForwardTTSE2e(BaseTTSE2E):
             mel_fmax=self.config.audio.mel_fmax,
             mel_fmin=self.config.audio.mel_fmin,
         )
+
+    @staticmethod
+    def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False):
+        """Initiate model from config
+
+        Args:
+            config (ForwardTTSE2eConfig): Model config.
+            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
+                Defaults to None.
+        """
+        from TTS.utils.audio.processor import AudioProcessor
+
+        tokenizer, new_config = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config, samples)
+        # language_manager = LanguageManager.init_from_config(config)
+        return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager)
+
+    def load_checkpoint(
+        self, config, checkpoint_path, eval=False
+    ):
+        """Load model from a checkpoint created by the 👟"""
+        # pylint: disable=unused-argument, redefined-builtin
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        self.load_state_dict(state["model"])
+        if eval:
+            self.eval()
+            assert not self.training
+
+    def get_state_dict(self):
+        """Custom state dict of the model with all the necessary components for inference."""
+        save_state = {
+            "config": self.config.to_dict(),
+            "args": self.args.to_dict(),
+            "model": self.state_dict(),
+        }
+
+        if hasattr(self, "emb_g"):
+            save_state["speaker_ids"] = self.speaker_manager.speaker_ids
+
+        if self.args.use_d_vector_file:
+            # TODO: implement saving of d_vectors
+            ...
+        return save_state
+
+    def save(self, config, checkpoint_path):
+        """Save model to a file."""
+        save_state = self.get_state_dict()
+        torch.save(save_state, checkpoint_path)
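
For reference, a minimal save/load round trip with the methods added above (illustrative; the checkpoint path and num_chars are placeholders):

    from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig
    from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e

    config = FastPitchE2eConfig(num_chars=10)
    model = ForwardTTSE2e.init_from_config(config)

    # save() delegates to get_state_dict(), bundling config, args, and weights.
    model.save(config, "checkpoint.pth")

    # load_checkpoint() restores the weights and switches to eval mode.
    restored = ForwardTTSE2e.init_from_config(config)
    restored.load_checkpoint(config, "checkpoint.pth", eval=True)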